diff --git a/cli/azd/CHANGELOG.md b/cli/azd/CHANGELOG.md index 3a75ae7cdb3..0de0cbc0f8c 100644 --- a/cli/azd/CHANGELOG.md +++ b/cli/azd/CHANGELOG.md @@ -4,6 +4,13 @@ ### Features Added +- Add `ConfigHelper` for typed, ergonomic access to azd user and environment configuration through gRPC services, with validation support, shallow/deep merge, and structured error types (`ConfigError`). +- Add `Pager[T]` generic pagination helper with SSRF-safe nextLink validation, `Collect` with `MaxPages`/`MaxItems` bounds, and `Truncated()` detection for callers. +- Add `ResilientClient` hardening: exponential backoff with jitter, upfront body seekability validation, and `Retry-After` header cap at 120 s. +- Add `SSRFGuard` standalone SSRF protection with metadata endpoint blocking, private network blocking, HTTPS enforcement, DNS fail-closed, IPv6 embedding extraction, and allowlist bypass. +- Add atomic file operations (`WriteFileAtomic`, `CopyFileAtomic`, `BackupFile`, `EnsureDir`) with crash-safe write-temp-rename pattern. +- Add runtime process utilities for cross-platform process management, tool discovery, and shell execution helpers. + ### Breaking Changes ### Bugs Fixed diff --git a/cli/azd/pkg/azdext/atomicfile.go b/cli/azd/pkg/azdext/atomicfile.go new file mode 100644 index 00000000000..7acac0cbc07 --- /dev/null +++ b/cli/azd/pkg/azdext/atomicfile.go @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/azure/azure-dev/cli/azd/pkg/osutil" +) + +// --------------------------------------------------------------------------- +// P3-2: Atomic file operations +// --------------------------------------------------------------------------- + +// WriteFileAtomic writes data to the named file atomically. It writes to a +// temporary file in the same directory as path and renames it into place. 
This +// ensures that readers never see a partially-written file and that the +// operation is crash-safe on filesystems that support atomic rename (ext4, +// APFS, NTFS). +// +// Platform behavior: +// - Unix: os.Rename is atomic within the same filesystem. +// - Windows: os.Rename replaces the target if it exists (Go 1.16+). On +// older Go runtimes or cross-device moves, the operation may fail. +// WriteFileAtomic always places the temp file in the same directory to +// avoid cross-device issues. +// +// The file is created with the specified permissions. If the target already +// exists its permissions are preserved unless perm is explicitly non-zero. +// +// Returns an error if the directory does not exist, the temp file cannot be +// created, data cannot be written, or the rename fails. +func WriteFileAtomic(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + + // Validate that the target directory exists. + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: target directory: %w", err) + } + + // If perm is zero and the target exists, preserve existing permissions. + if perm == 0 { + if fi, err := os.Stat(path); err == nil { + perm = fi.Mode().Perm() + } else { + perm = 0o644 + } + } + + // Create temp file in the same directory (same filesystem = atomic rename). + tmp, err := os.CreateTemp(dir, ".azdext-atomic-*") + if err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: create temp: %w", err) + } + tmpPath := tmp.Name() + + // Ensure cleanup on any failure path. + success := false + defer func() { + if !success { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + }() + + // Write data and sync to disk. 
+ if _, err := tmp.Write(data); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: write: %w", err) + } + if err := tmp.Sync(); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: sync: %w", err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: close: %w", err) + } + + // Set permissions on temp file before rename. + if err := os.Chmod(tmpPath, perm); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: chmod: %w", err) + } + + // Atomic rename into place. + if err := osutil.Rename(context.Background(), tmpPath, path); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: rename: %w", err) + } + + success = true + return nil +} + +// CopyFileAtomic copies src to dst atomically using the write-temp-rename +// pattern. The destination file is never in a partially-written state. +// The copy is streamed through a fixed-size buffer (no unbounded memory +// allocation regardless of source file size). +// +// Platform behavior: see [WriteFileAtomic]. +// +// If perm is zero, the source file's permissions are used. +func CopyFileAtomic(src, dst string, perm os.FileMode) error { + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: open source: %w", err) + } + defer srcFile.Close() + + // Determine permissions. + if perm == 0 { + if fi, err := srcFile.Stat(); err == nil { + perm = fi.Mode().Perm() + } else { + perm = 0o644 + } + } + + dir := filepath.Dir(dst) + + // Validate that the target directory exists. + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: target directory: %w", err) + } + + // Create temp file in the same directory (same filesystem = atomic rename). + tmp, err := os.CreateTemp(dir, ".azdext-atomic-*") + if err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: create temp: %w", err) + } + tmpPath := tmp.Name() + + // Ensure cleanup on any failure path. 
+ success := false + defer func() { + if !success { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + }() + + // Stream copy with fixed-size buffer — no unbounded memory allocation. + if _, err := io.Copy(tmp, srcFile); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: copy: %w", err) + } + + if err := tmp.Sync(); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: sync: %w", err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: close: %w", err) + } + + // Set permissions on temp file before rename. + if err := os.Chmod(tmpPath, perm); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: chmod: %w", err) + } + + // Atomic rename into place. + if err := osutil.Rename(context.Background(), tmpPath, dst); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: rename: %w", err) + } + + success = true + return nil +} + +// BackupFile creates a backup copy of path at path+suffix using atomic copy. +// If the source file does not exist, it returns nil (no backup needed). +// +// The default suffix is ".bak" if suffix is empty. +// +// Returns the backup path on success, or an error if the copy fails. +func BackupFile(path, suffix string) (string, error) { + if suffix == "" { + suffix = ".bak" + } + + if _, err := os.Stat(path); os.IsNotExist(err) { + return "", nil // Nothing to back up. + } + + backupPath := path + suffix + if err := CopyFileAtomic(path, backupPath, 0); err != nil { + return "", fmt.Errorf("azdext.BackupFile: %w", err) + } + + return backupPath, nil +} + +// EnsureDir creates directory dir and any necessary parents with the given +// permissions. If the directory already exists, EnsureDir is a no-op and +// returns nil. +// +// This is a convenience wrapper around [os.MkdirAll] with an explicit error +// prefix for diagnostics. 
+// +// Security: EnsureDir cleans the path via [filepath.Clean] and rejects paths +// containing parent-directory traversal ("..") to prevent creating directories +// outside the caller's intended scope. For untrusted input, callers should +// additionally use [MCPSecurityPolicy.CheckPath] for base-directory validation. +func EnsureDir(dir string, perm os.FileMode) error { + if perm == 0 { + perm = 0o755 + } + + // Reject paths containing parent traversal sequences. + cleaned := filepath.Clean(dir) + if strings.Contains(cleaned, "..") { + return fmt.Errorf("azdext.EnsureDir: path traversal detected in %q", dir) + } + + if err := os.MkdirAll(cleaned, perm); err != nil { + return fmt.Errorf("azdext.EnsureDir: %w", err) + } + return nil +} diff --git a/cli/azd/pkg/azdext/atomicfile_test.go b/cli/azd/pkg/azdext/atomicfile_test.go new file mode 100644 index 00000000000..e2c2394b732 --- /dev/null +++ b/cli/azd/pkg/azdext/atomicfile_test.go @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "os" + "path/filepath" + "testing" +) + +// --------------------------------------------------------------------------- +// WriteFileAtomic +// --------------------------------------------------------------------------- + +func TestWriteFileAtomic_NewFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + data := []byte("hello, atomic world") + + if err := WriteFileAtomic(path, data, 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + // Verify content. + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != string(data) { + t.Errorf("WriteFileAtomic() content = %q, want %q", got, data) + } + + // Verify permissions. + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + // On Windows, permission bits are limited. Just verify the file exists. 
+ if fi.Size() != int64(len(data)) { + t.Errorf("WriteFileAtomic() file size = %d, want %d", fi.Size(), len(data)) + } +} + +func TestWriteFileAtomic_Overwrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + + // Write initial content. + if err := os.WriteFile(path, []byte("old"), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + // Overwrite atomically. + newData := []byte("new content here") + if err := WriteFileAtomic(path, newData, 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != string(newData) { + t.Errorf("WriteFileAtomic() overwrite content = %q, want %q", got, newData) + } +} + +func TestWriteFileAtomic_EmptyData(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.txt") + + if err := WriteFileAtomic(path, []byte{}, 0o644); err != nil { + t.Fatalf("WriteFileAtomic(empty) error: %v", err) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if fi.Size() != 0 { + t.Errorf("WriteFileAtomic(empty) size = %d, want 0", fi.Size()) + } +} + +func TestWriteFileAtomic_MissingDirectory(t *testing.T) { + path := filepath.Join(t.TempDir(), "nonexistent", "file.txt") + err := WriteFileAtomic(path, []byte("data"), 0o644) + if err == nil { + t.Error("WriteFileAtomic() with missing directory = nil, want error") + } +} + +func TestWriteFileAtomic_PreservesPermissions(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "perm.txt") + + // Write with specific permissions. + if err := WriteFileAtomic(path, []byte("first"), 0o600); err != nil { + t.Fatalf("WriteFileAtomic(first) error: %v", err) + } + + // Overwrite with perm=0 to preserve existing. 
+ if err := WriteFileAtomic(path, []byte("second"), 0); err != nil { + t.Fatalf("WriteFileAtomic(second) error: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != "second" { + t.Errorf("content = %q, want %q", got, "second") + } +} + +func TestWriteFileAtomic_NoTempFileLeftOnFailure(t *testing.T) { + dir := t.TempDir() + + // Write a valid file first. + path := filepath.Join(dir, "test.txt") + if err := WriteFileAtomic(path, []byte("ok"), 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + // List directory to establish baseline. + before, _ := os.ReadDir(dir) + beforeCount := len(before) + + // Attempt write to a non-existent sub-directory (should fail). + badPath := filepath.Join(dir, "nodir", "bad.txt") + _ = WriteFileAtomic(badPath, []byte("fail"), 0o644) + + // Verify no temp files were left behind. + after, _ := os.ReadDir(dir) + if len(after) != beforeCount { + t.Errorf("temp file leak: before=%d entries, after=%d entries", beforeCount, len(after)) + } +} + +// --------------------------------------------------------------------------- +// CopyFileAtomic +// --------------------------------------------------------------------------- + +func TestCopyFileAtomic(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "source.txt") + dst := filepath.Join(dir, "dest.txt") + data := []byte("copy me atomically") + + if err := os.WriteFile(src, data, 0o600); err != nil { + t.Fatalf("WriteFile(src) error: %v", err) + } + + if err := CopyFileAtomic(src, dst, 0); err != nil { + t.Fatalf("CopyFileAtomic() error: %v", err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("ReadFile(dst) error: %v", err) + } + if string(got) != string(data) { + t.Errorf("CopyFileAtomic() content = %q, want %q", got, data) + } +} + +func TestCopyFileAtomic_SourceNotFound(t *testing.T) { + dir := t.TempDir() + err := CopyFileAtomic(filepath.Join(dir, 
"missing.txt"), filepath.Join(dir, "dst.txt"), 0) + if err == nil { + t.Error("CopyFileAtomic() with missing source = nil, want error") + } +} + +// --------------------------------------------------------------------------- +// BackupFile +// --------------------------------------------------------------------------- + +func TestBackupFile_CreatesBackup(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + data := []byte("config: value") + + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + backupPath, err := BackupFile(path, ".bak") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != path+".bak" { + t.Errorf("BackupFile() path = %q, want %q", backupPath, path+".bak") + } + + got, err := os.ReadFile(backupPath) + if err != nil { + t.Fatalf("ReadFile(backup) error: %v", err) + } + if string(got) != string(data) { + t.Errorf("BackupFile() content = %q, want %q", got, data) + } +} + +func TestBackupFile_DefaultSuffix(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data.json") + if err := os.WriteFile(path, []byte("{}"), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + backupPath, err := BackupFile(path, "") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != path+".bak" { + t.Errorf("BackupFile() default suffix: path = %q, want %q", backupPath, path+".bak") + } +} + +func TestBackupFile_SourceNotFound(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "nonexistent.txt") + + backupPath, err := BackupFile(path, ".bak") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != "" { + t.Errorf("BackupFile() for nonexistent source: path = %q, want empty", backupPath) + } +} + +// --------------------------------------------------------------------------- +// EnsureDir +// --------------------------------------------------------------------------- + 
+func TestEnsureDir_CreatesNew(t *testing.T) { + dir := filepath.Join(t.TempDir(), "a", "b", "c") + + if err := EnsureDir(dir, 0o755); err != nil { + t.Fatalf("EnsureDir() error: %v", err) + } + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if !fi.IsDir() { + t.Error("EnsureDir() did not create a directory") + } +} + +func TestEnsureDir_ExistingIsNoOp(t *testing.T) { + dir := t.TempDir() + if err := EnsureDir(dir, 0o755); err != nil { + t.Fatalf("EnsureDir() on existing dir error: %v", err) + } +} + +func TestEnsureDir_DefaultPermissions(t *testing.T) { + dir := filepath.Join(t.TempDir(), "default-perm") + if err := EnsureDir(dir, 0); err != nil { + t.Fatalf("EnsureDir(perm=0) error: %v", err) + } + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if !fi.IsDir() { + t.Error("EnsureDir(perm=0) did not create a directory") + } +} + +func TestCopyFileAtomic_LargeFileStreaming(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "large.bin") + dst := filepath.Join(dir, "large-copy.bin") + + // Create a source file larger than a typical io.Copy buffer (32 KB). + // This verifies the streaming path is exercised. + const size = 256 * 1024 // 256 KB + data := make([]byte, size) + for i := range data { + data[i] = byte(i % 251) // deterministic non-zero pattern + } + if err := os.WriteFile(src, data, 0o644); err != nil { + t.Fatalf("WriteFile(src) error: %v", err) + } + + if err := CopyFileAtomic(src, dst, 0); err != nil { + t.Fatalf("CopyFileAtomic() error: %v", err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("ReadFile(dst) error: %v", err) + } + if len(got) != size { + t.Errorf("CopyFileAtomic() size = %d, want %d", len(got), size) + } + // Spot-check content integrity. 
+ for _, idx := range []int{0, 1, size / 2, size - 1} { + if got[idx] != data[idx] { + t.Errorf("content mismatch at byte %d: got %d, want %d", idx, got[idx], data[idx]) + } + } +} diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go new file mode 100644 index 00000000000..60917274fb3 --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strings" +) + +// ConfigHelper provides typed, ergonomic access to azd configuration through +// the gRPC UserConfig and Environment services. It eliminates the boilerplate +// of raw gRPC calls and JSON marshaling that extension authors otherwise need. +// +// Configuration sources (in merge priority, lowest to highest): +// 1. User config (global azd config) — via UserConfigService +// 2. Environment config (per-env) — via EnvironmentService +// +// Usage: +// +// ch := azdext.NewConfigHelper(client) +// port, err := ch.GetUserString(ctx, "extensions.myext.port") +// var cfg MyConfig +// err = ch.GetUserJSON(ctx, "extensions.myext", &cfg) +type ConfigHelper struct { + client *AzdClient +} + +// NewConfigHelper creates a [ConfigHelper] for the given AZD client. +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) { + if client == nil { + return nil, errors.New("azdext.NewConfigHelper: client must not be nil") + } + + return &ConfigHelper{client: client}, nil +} + +// --- User Config (global) --- + +// GetUserString retrieves a string value from the global user config at the +// given dot-separated path. Returns ("", false, nil) when the path does not +// exist, and ("", false, err) on gRPC errors. 
+func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. 
+func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. +func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. 
+func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. 
+func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). 
+func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. 
+func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. 
+ Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +var configSegmentRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`) + +func validatePath(path string) error { + if path == "" { + return errors.New("azdext.ConfigHelper: config path must not be empty") + } + if strings.HasPrefix(path, ".") || strings.HasSuffix(path, ".") || strings.Contains(path, "..") { + return errors.New( + "azdext.ConfigHelper: config path must not have empty segments " + + "(no leading/trailing dots or consecutive dots)", + ) + } + for _, seg := range strings.Split(path, ".") { + if !configSegmentRe.MatchString(seg) { + return fmt.Errorf( + "azdext.ConfigHelper: config path segment %q must start with alphanumeric "+ + "and contain only [a-zA-Z0-9_-], max 63 chars", + truncateConfigValue(seg, 64), + ) + } + } + return nil +} + +func truncateConfigValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go new file mode 100644 index 00000000000..c68464f16bc --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) 
(*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := 
NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig 
struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = 
%T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := 
ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + 
+func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != 
nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", 
result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] != false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func 
TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator 
:= func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + 
{ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..0f78416656c --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Secret references use the akvs:// URI scheme: +// +// akvs://// +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions +} + +// secretClientFactory abstracts secret client creation for testability. +type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. 
+type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider] from P1-1. +// If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client. +func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for an akvs:// reference. +// +// The reference must match the format: akvs://// +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. 
+func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) {
+	if ctx == nil {
+		return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil")
+	}
+
+	// Parse and validate the akvs:// reference before any network activity.
+	parsed, err := ParseSecretReference(ref)
+	if err != nil {
+		return "", &KeyVaultResolveError{
+			Reference: ref,
+			Reason:    ResolveReasonInvalidReference,
+			Err:       err,
+		}
+	}
+
+	vaultURL := fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix)
+
+	client, err := r.clientFactory(vaultURL, r.credential)
+	if err != nil {
+		return "", &KeyVaultResolveError{
+			Reference: ref,
+			Reason:    ResolveReasonClientCreation,
+			Err:       fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err),
+		}
+	}
+
+	// Empty version string — presumably requests the latest secret version
+	// per the azsecrets client convention; confirm against SDK docs.
+	resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil)
+	if err != nil {
+		// NOTE(review): errors that are not *azcore.ResponseError (e.g. pure
+		// transport/DNS failures) keep this access_denied default; classifying
+		// them as service_error may be more accurate — confirm intent.
+		reason := ResolveReasonAccessDenied
+
+		var respErr *azcore.ResponseError
+		if errors.As(err, &respErr) {
+			switch respErr.StatusCode {
+			case http.StatusNotFound:
+				reason = ResolveReasonNotFound
+			case http.StatusForbidden, http.StatusUnauthorized:
+				reason = ResolveReasonAccessDenied
+			default:
+				reason = ResolveReasonServiceError
+			}
+		}
+
+		return "", &KeyVaultResolveError{
+			Reference: ref,
+			Reason:    reason,
+			Err: fmt.Errorf(
+				"failed to retrieve secret %q from vault %q: %w",
+				parsed.SecretName,
+				parsed.VaultName,
+				err,
+			),
+		}
+	}
+
+	// A secret that exists but carries no value is surfaced as not_found
+	// rather than returning an empty string silently.
+	if resp.Value == nil {
+		return "", &KeyVaultResolveError{
+			Reference: ref,
+			Reason:    ResolveReasonNotFound,
+			Err:       fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName),
+		}
+	}
+
+	return *resp.Value, nil
+}
+
+// ResolveMap resolves a map of key → akvs:// references, returning a map of
+// key → resolved secret values. Processing stops at the first error.
+//
+// Non-akvs:// values are passed through unchanged, so callers can safely
+// resolve a mixed map of plain values and secret references.
+func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) {
+	if ctx == nil {
+		return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil")
+	}
+
+	result := make(map[string]string, len(refs))
+
+	// NOTE(review): Go map iteration order is randomized, so when several
+	// entries fail, which key's error surfaces "first" is nondeterministic.
+	// Sorting the keys first would make error reporting reproducible.
+	for key, value := range refs {
+		// Plain (non-akvs://) values pass through untouched.
+		if !IsSecretReference(value) {
+			result[key] = value
+			continue
+		}
+
+		resolved, err := r.Resolve(ctx, value)
+		if err != nil {
+			return nil, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: key %q: %w", key, err)
+		}
+
+		result[key] = resolved
+	}
+
+	return result, nil
+}
+
+// SecretReference represents a parsed akvs:// URI.
+type SecretReference struct {
+	// SubscriptionID is the Azure subscription containing the Key Vault.
+	SubscriptionID string
+
+	// VaultName is the Key Vault name (not the full URL).
+	VaultName string
+
+	// SecretName is the name of the secret within the vault.
+	SecretName string
+}
+
+// IsSecretReference reports whether s uses the akvs:// scheme.
+// Delegates to the shared keyvault package so detection stays consistent
+// with the rest of azd.
+func IsSecretReference(s string) bool {
+	return keyvault.IsAzureKeyVaultSecret(s)
+}
+
+// vaultNameRe validates Azure Key Vault names per Azure naming rules:
+//   - 3–24 characters
+//   - starts with a letter
+//   - contains only alphanumeric and hyphens
+//   - does not end with a hyphen
+var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`)
+
+// ParseSecretReference parses an akvs:// URI into its components.
+//
+// Expected format: akvs:////
+//
+// The vault name is validated against Azure Key Vault naming rules (3–24
+// characters, starts with letter, alphanumeric and hyphens only, does not
+// end with a hyphen).
+func ParseSecretReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err + } + + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, and contain only alphanumeric characters and non-consecutive hyphens", + ref, parsed.VaultName, + ) + } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) + } + + return &SecretReference{ + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError]. +type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the akvs:// URI is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. 
+func (r ResolveReason) String() string {
+	switch r {
+	case ResolveReasonInvalidReference:
+		return "invalid_reference"
+	case ResolveReasonClientCreation:
+		return "client_creation"
+	case ResolveReasonNotFound:
+		return "not_found"
+	case ResolveReasonAccessDenied:
+		return "access_denied"
+	case ResolveReasonServiceError:
+		return "service_error"
+	default:
+		// Unrecognized values (e.g. future additions) map to a stable label.
+		return "unknown"
+	}
+}
+
+// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails.
+type KeyVaultResolveError struct {
+	// Reference is the original akvs:// URI that was being resolved.
+	Reference string
+
+	// Reason classifies the failure.
+	Reason ResolveReason
+
+	// Err is the underlying error.
+	Err error
+}
+
+// Error renders the reason label, the original reference, and the underlying
+// error in a single line.
+func (e *KeyVaultResolveError) Error() string {
+	return fmt.Sprintf(
+		"azdext.KeyVaultResolver: %s (ref=%s): %v",
+		e.Reason, e.Reference, e.Err,
+	)
+}
+
+// Unwrap exposes the underlying error so callers can use errors.Is/errors.As.
+func (e *KeyVaultResolveError) Unwrap() error {
+	return e.Err
+}
diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go
new file mode 100644
index 00000000000..8f311afee9b
--- /dev/null
+++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go
@@ -0,0 +1,577 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package azdext
+
+import (
+	"context"
+	"errors"
+	"net/http"
+	"testing"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
+	"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets"
+)
+
+// stubSecretGetter is a test double for the Key Vault data-plane client.
+type stubSecretGetter struct {
+	resp azsecrets.GetSecretResponse
+	err  error
+}
+
+func (s *stubSecretGetter) GetSecret(
+	_ context.Context, _ string, _ string, _ *azsecrets.GetSecretOptions,
+) (azsecrets.GetSecretResponse, error) {
+	return s.resp, s.err
+}
+
+// stubSecretFactory returns a factory that always returns the given stubSecretGetter.
+func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected 
error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + + val, err := resolver.Resolve(context.Background(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + //lint:ignore SA1012 intentionally testing nil context handling + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(context.Background(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != 
ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_SecretNotFound(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/missing-secret") + if err == nil { + t.Fatal("expected error for missing secret") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_AccessDenied(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusForbidden}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for forbidden access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_Unauthorized(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusUnauthorized}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := 
resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for unauthorized access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_ServiceError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for server error") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t 
*testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to access_denied + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(context.Background(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(context.Background(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + 
} + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestResolveMap_ErrorStopsProcessing(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "secret": "akvs://sub/vault/missing", + } + + _, err := resolver.ResolveMap(context.Background(), input) + if err == nil { + t.Fatal("expected error when resolution fails") + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + //lint:ignore SA1012 intentionally testing nil context handling + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + 
{ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + {ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/logger.go b/cli/azd/pkg/azdext/logger.go new file mode 100644 index 00000000000..d0d156ca565 --- /dev/null +++ b/cli/azd/pkg/azdext/logger.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "io" + "log/slog" + "os" + "strings" +) + +// LoggerOptions configures [SetupLogging] and [NewLogger]. +type LoggerOptions struct { + // Debug enables debug-level logging. When false, messages below Info are + // suppressed. If not set explicitly, [NewLogger] checks the AZD_DEBUG + // environment variable. + Debug bool + // Structured selects JSON output when true, human-readable text when false. + Structured bool + // Writer overrides the output destination. Defaults to os.Stderr. + Writer io.Writer +} + +// SetupLogging configures the process-wide default [slog.Logger]. +// It is typically called once at startup (for example from +// [NewExtensionRootCommand]'s PersistentPreRunE callback). +// +// Calling SetupLogging is optional — [NewLogger] works without it and creates +// loggers that inherit from [slog.Default]. SetupLogging is provided for +// extensions that want explicit control over the global log level and format. +func SetupLogging(opts LoggerOptions) { + handler := newHandler(opts) + slog.SetDefault(slog.New(handler)) +} + +// Logger provides component-scoped structured logging built on [log/slog]. +// +// Each Logger carries a "component" attribute so log lines can be filtered or +// routed by subsystem. 
Additional context can be attached via [Logger.With],
// [Logger.WithComponent], or [Logger.WithOperation].
//
// Logger writes to stderr by default and never writes to stdout, so it does
// not interfere with command output or JSON-mode piping.
type Logger struct {
	// slogger carries the "component" attribute plus any chained context.
	slogger *slog.Logger
	// component is the name this logger was created with; see Component.
	component string
}

// NewLogger creates a Logger scoped to the given component name.
//
// If the AZD_DEBUG environment variable is set to a truthy value ("1", "true",
// "yes") and opts.Debug is false, debug logging is enabled automatically. This
// lets extension authors respect the framework's debug flag without extra
// plumbing.
//
// NOTE: an explicit Debug=false cannot suppress AZD_DEBUG — when the
// environment variable is truthy, debug logging is always enabled.
//
// When opts is omitted (zero value), the logger uses Info level with text
// format on stderr.
func NewLogger(component string, opts ...LoggerOptions) *Logger {
	var o LoggerOptions
	if len(opts) > 0 {
		o = opts[0]
	}

	// Auto-detect debug from environment when not explicitly set.
	if !o.Debug {
		o.Debug = isDebugEnv()
	}

	handler := newHandler(o)
	base := slog.New(handler).With("component", component)

	return &Logger{
		slogger:   base,
		component: component,
	}
}

// Component returns the component name this logger was created with.
func (l *Logger) Component() string {
	return l.component
}

// Debug logs a message at debug level with optional key-value pairs.
func (l *Logger) Debug(msg string, args ...any) {
	l.slogger.Debug(msg, args...)
}

// Info logs a message at info level with optional key-value pairs.
func (l *Logger) Info(msg string, args ...any) {
	l.slogger.Info(msg, args...)
}

// Warn logs a message at warn level with optional key-value pairs.
func (l *Logger) Warn(msg string, args ...any) {
	l.slogger.Warn(msg, args...)
}

// Error logs a message at error level with optional key-value pairs.
func (l *Logger) Error(msg string, args ...any) {
	l.slogger.Error(msg, args...)
}

// With returns a new Logger that includes the given key-value pairs in every
// subsequent log entry. Keys must be strings; values can be any type
// supported by [slog]. The receiver is not modified.
//
// Example:
//
//	l := logger.With("request_id", reqID)
//	l.Info("processing") // includes component + request_id
func (l *Logger) With(args ...any) *Logger {
	return &Logger{
		slogger:   l.slogger.With(args...),
		component: l.component,
	}
}

// WithComponent returns a new Logger with a different component name. The
// original component is preserved as "parent_component".
func (l *Logger) WithComponent(name string) *Logger {
	return &Logger{
		slogger:   l.slogger.With("parent_component", l.component, "component", name),
		component: name,
	}
}

// WithOperation returns a new Logger with an "operation" attribute. The
// component name is unchanged.
func (l *Logger) WithOperation(name string) *Logger {
	return &Logger{
		slogger: l.slogger.With("operation", name),
		component: l.component,
	}
}

// Slogger returns the underlying [*slog.Logger] for advanced use cases such
// as passing to libraries that accept a standard slog logger.
func (l *Logger) Slogger() *slog.Logger {
	return l.slogger
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

// newHandler creates an slog.Handler from LoggerOptions: JSON when
// Structured is set, text otherwise; Debug level when Debug is set, Info
// otherwise; os.Stderr when no Writer is provided.
func newHandler(opts LoggerOptions) slog.Handler {
	w := opts.Writer
	if w == nil {
		w = os.Stderr
	}

	level := slog.LevelInfo
	if opts.Debug {
		level = slog.LevelDebug
	}

	handlerOpts := &slog.HandlerOptions{Level: level}

	if opts.Structured {
		return slog.NewJSONHandler(w, handlerOpts)
	}
	return slog.NewTextHandler(w, handlerOpts)
}

// isDebugEnv checks the AZD_DEBUG environment variable.
//
// Security note: AZD_DEBUG enables verbose logging that may include
// request details, configuration paths, and internal state.
It should
// NOT be enabled in production deployments. The variable is intended
// for local development and CI debugging only.
func isDebugEnv() bool {
	// Case-insensitive match; only "1", "true", and "yes" count as truthy.
	v := strings.ToLower(os.Getenv("AZD_DEBUG"))
	return v == "1" || v == "true" || v == "yes"
}
diff --git a/cli/azd/pkg/azdext/logger_test.go b/cli/azd/pkg/azdext/logger_test.go
new file mode 100644
index 00000000000..9e21b763ada
--- /dev/null
+++ b/cli/azd/pkg/azdext/logger_test.go
@@ -0,0 +1,295 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azdext

import (
	"bytes"
	"encoding/json"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

// ---------------------------------------------------------------------------
// NewLogger — basic construction
// ---------------------------------------------------------------------------

func TestNewLogger_DefaultOptions(t *testing.T) {
	var buf bytes.Buffer
	logger := NewLogger("test-component", LoggerOptions{Writer: &buf})

	require.NotNil(t, logger)
	require.Equal(t, "test-component", logger.Component())
}

func TestNewLogger_ZeroOptions(t *testing.T) {
	// Zero-value opts should not panic (writes to stderr).
	logger := NewLogger("safe")
	require.NotNil(t, logger)
	require.Equal(t, "safe", logger.Component())
}

func TestNewLogger_NoOpts(t *testing.T) {
	// Calling without variadic opts should not panic.
+ logger := NewLogger("minimal") + require.NotNil(t, logger) +} + +// --------------------------------------------------------------------------- +// Log levels — Info +// --------------------------------------------------------------------------- + +func TestLogger_Info(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("mycomp", LoggerOptions{Writer: &buf}) + + logger.Info("hello world", "key", "val") + + output := buf.String() + require.Contains(t, output, "hello world") + require.Contains(t, output, "key=val") + require.Contains(t, output, "component=mycomp") +} + +// --------------------------------------------------------------------------- +// Log levels — Debug +// --------------------------------------------------------------------------- + +func TestLogger_Debug_Enabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: true, Writer: &buf}) + + logger.Debug("debug message", "detail", "x") + + require.Contains(t, buf.String(), "debug message") + require.Contains(t, buf.String(), "detail=x") +} + +func TestLogger_Debug_Disabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: false, Writer: &buf}) + + logger.Debug("should not appear") + + require.Empty(t, buf.String()) +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar(t *testing.T) { + t.Setenv("AZD_DEBUG", "true") + + var buf bytes.Buffer + logger := NewLogger("env-debug", LoggerOptions{Writer: &buf}) + + logger.Debug("from env var") + + require.Contains(t, buf.String(), "from env var") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_One(t *testing.T) { + t.Setenv("AZD_DEBUG", "1") + + var buf bytes.Buffer + logger := NewLogger("env-one", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via 1") + + require.Contains(t, buf.String(), "debug via 1") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_Yes(t *testing.T) { + t.Setenv("AZD_DEBUG", "yes") + + var buf bytes.Buffer + logger := NewLogger("env-yes", LoggerOptions{Writer: 
&buf}) + + logger.Debug("debug via yes") + + require.Contains(t, buf.String(), "debug via yes") +} + +func TestLogger_Debug_AZD_DEBUG_Unset(t *testing.T) { + t.Setenv("AZD_DEBUG", "") + + var buf bytes.Buffer + logger := NewLogger("env-empty", LoggerOptions{Writer: &buf}) + + logger.Debug("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Log levels — Warn / Error +// --------------------------------------------------------------------------- + +func TestLogger_Warn(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn-test", LoggerOptions{Writer: &buf}) + + logger.Warn("something concerning", "retries", 3) + + require.Contains(t, buf.String(), "something concerning") + require.Contains(t, buf.String(), "retries=3") +} + +func TestLogger_Error(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("err-test", LoggerOptions{Writer: &buf}) + + logger.Error("bad thing happened", "code", 500) + + require.Contains(t, buf.String(), "bad thing happened") + require.Contains(t, buf.String(), "code=500") +} + +// --------------------------------------------------------------------------- +// Structured (JSON) output +// --------------------------------------------------------------------------- + +func TestLogger_StructuredJSON(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("json-comp", LoggerOptions{Structured: true, Writer: &buf}) + + logger.Info("structured entry", "env", "prod") + + // Each line should be valid JSON. 
+ var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "structured entry", parsed["msg"]) + require.Equal(t, "prod", parsed["env"]) + require.Equal(t, "json-comp", parsed["component"]) +} + +// --------------------------------------------------------------------------- +// Context chaining — With +// --------------------------------------------------------------------------- + +func TestLogger_With(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("base", LoggerOptions{Writer: &buf}) + + child := logger.With("request_id", "abc-123") + child.Info("processing") + + output := buf.String() + require.Contains(t, output, "request_id=abc-123") + require.Contains(t, output, "component=base") + require.Contains(t, output, "processing") + + // Child should have the same component. + require.Equal(t, "base", child.Component()) +} + +func TestLogger_With_ChainMultiple(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("chain", LoggerOptions{Writer: &buf}) + + child := logger.With("a", "1").With("b", "2") + child.Info("chained") + + output := buf.String() + require.Contains(t, output, "a=1") + require.Contains(t, output, "b=2") +} + +// --------------------------------------------------------------------------- +// Context chaining — WithComponent +// --------------------------------------------------------------------------- + +func TestLogger_WithComponent(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("parent", LoggerOptions{Structured: true, Writer: &buf}) + + child := logger.WithComponent("child-subsystem") + child.Info("from child") + + require.Equal(t, "child-subsystem", child.Component()) + + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "child-subsystem", parsed["component"]) + require.Equal(t, "parent", parsed["parent_component"]) +} + +// 
--------------------------------------------------------------------------- +// Context chaining — WithOperation +// --------------------------------------------------------------------------- + +func TestLogger_WithOperation(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("ops", LoggerOptions{Writer: &buf}) + + child := logger.WithOperation("deploy") + child.Info("starting deploy") + + output := buf.String() + require.Contains(t, output, "operation=deploy") + require.Equal(t, "ops", child.Component()) +} + +// --------------------------------------------------------------------------- +// Slogger accessor +// --------------------------------------------------------------------------- + +func TestLogger_Slogger(t *testing.T) { + logger := NewLogger("access", LoggerOptions{Writer: &bytes.Buffer{}}) + require.NotNil(t, logger.Slogger()) +} + +// --------------------------------------------------------------------------- +// SetupLogging — global logger configuration +// --------------------------------------------------------------------------- + +func TestSetupLogging_DoesNotPanic(t *testing.T) { + // SetupLogging modifies slog.Default which is global state. + // We only verify it does not panic here. + var buf bytes.Buffer + SetupLogging(LoggerOptions{Debug: true, Structured: true, Writer: &buf}) + + // Restore a sensible default after the test. 
+ SetupLogging(LoggerOptions{Writer: &bytes.Buffer{}}) +} + +// --------------------------------------------------------------------------- +// isDebugEnv internal helper +// --------------------------------------------------------------------------- + +func TestIsDebugEnv_Truthy(t *testing.T) { + truthy := []string{"1", "true", "TRUE", "True", "yes", "YES", "Yes"} + for _, v := range truthy { + t.Run(v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.True(t, isDebugEnv()) + }) + } +} + +func TestIsDebugEnv_Falsy(t *testing.T) { + falsy := []string{"", "0", "false", "no", "maybe"} + for _, v := range falsy { + t.Run("value="+v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.False(t, isDebugEnv()) + }) + } +} + +// --------------------------------------------------------------------------- +// Text format verification +// --------------------------------------------------------------------------- + +func TestLogger_TextFormat_ContainsLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("lvl", LoggerOptions{Writer: &buf}) + + logger.Info("test level") + + output := buf.String() + require.True(t, + strings.Contains(output, "INFO") || strings.Contains(output, "level=INFO"), + "expected level indicator in text output: %s", output) +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..70a99261463 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" "path/filepath" @@ -23,6 +24,8 @@ type MCPSecurityPolicy struct { allowedBasePaths []string blockedCIDRs []*net.IPNet blockedHosts map[string]bool + // onBlocked is invoked whenever a URL or path is blocked, for audit logging. + onBlocked func(violation string) // lookupHost is used for DNS resolution; override in tests. 
lookupHost func(string) ([]string, error) } @@ -42,12 +45,7 @@ func (p *MCPSecurityPolicy) BlockMetadataEndpoints() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockMetadata = true - for _, host := range []string{ - "169.254.169.254", - "fd00:ec2::254", - "metadata.google.internal", - "100.100.100.200", - } { + for _, host := range ssrfMetadataHosts { p.blockedHosts[strings.ToLower(host)] = true } return p @@ -60,23 +58,7 @@ func (p *MCPSecurityPolicy) BlockPrivateNetworks() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockPrivate = true - for _, cidr := range []string{ - "0.0.0.0/8", // "this" network (reaches loopback on Linux/macOS) - "10.0.0.0/8", // RFC 1918 private - "172.16.0.0/12", // RFC 1918 private - "192.168.0.0/16", // RFC 1918 private - "127.0.0.0/8", // loopback - "100.64.0.0/10", // RFC 6598 shared/CGNAT (internal in cloud environments) - "169.254.0.0/16", // IPv4 link-local - "::1/128", // IPv6 loopback - "::/128", // IPv6 unspecified (reaches loopback) - "fc00::/7", // IPv6 unique local addresses (RFC 4193, equiv of RFC 1918) - "fe80::/10", // IPv6 link-local - "2002::/16", // 6to4 relay (deprecated RFC 7526; can embed private IPv4) - "2001::/32", // Teredo tunneling (deprecated; can embed private IPv4) - "64:ff9b::/96", // NAT64 well-known prefix (RFC 6052; embeds IPv4 in last 32 bits) - "64:ff9b:1::/48", // NAT64 local-use prefix (RFC 8215; embeds IPv4 in last 32 bits) - } { + for _, cidr := range ssrfBlockedCIDRs { _, ipNet, err := net.ParseCIDR(cidr) if err == nil { p.blockedCIDRs = append(p.blockedCIDRs, ipNet) @@ -111,6 +93,16 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } +// OnBlocked registers a callback invoked whenever a URL or path check fails. +// The callback receives a human-readable description of the violation. +// This is intended for audit logging; the callback must not block. 
+func (p *MCPSecurityPolicy) OnBlocked(fn func(violation string)) *MCPSecurityPolicy { + p.mu.Lock() + defer p.mu.Unlock() + p.onBlocked = fn + return p +} + // isLocalhostHost returns true if the host is localhost or a loopback address. func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -125,8 +117,18 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - defer p.mu.RUnlock() + fn := p.onBlocked + err := p.checkURLCore(rawURL) + p.mu.RUnlock() + if fn != nil && err != nil { + fn(err.Error()) + } + + return err +} + +func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) @@ -162,7 +164,8 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { if err != nil { // Fail-closed: if DNS resolution fails, block the request. // This prevents SSRF bypasses via DNS rebinding or transient failures. - return fmt.Errorf("DNS resolution failed for host %s: %w", host, err) + blockErr := fmt.Errorf("DNS resolution failed for host %s: %w", host, err) + return blockErr } for _, addr := range addrs { if p.blockedHosts[strings.ToLower(addr)] { @@ -180,70 +183,8 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { } func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(ip) { - return fmt.Errorf("blocked IP %s (CIDR %s) for host %s", ip, cidr, originalHost) - } - } - - if p.blockPrivate { - // Catch encoding variants (e.g., IPv4-compatible IPv6 like ::127.0.0.1) - // that may not match CIDR entries due to byte-length mismatch. 
- if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { - return fmt.Errorf("blocked IP %s (private/loopback/link-local) for host %s", ip, originalHost) - } - - // Handle encoding variants that Go's net.IP methods don't classify, by extracting - // the embedded IPv4 address and re-checking it against all blocked ranges. - if len(ip) == net.IPv6len && ip.To4() == nil { - // IPv4-compatible (::x.x.x.x, RFC 4291 §2.5.5.1): first 12 bytes are zero. - isV4Compatible := true - for i := 0; i < 12; i++ { - if ip[i] != 0 { - isV4Compatible = false - break - } - } - if isV4Compatible && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } - - // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765 §4.2.1): bytes 0-7 zero, - // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4. - // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil. 
- isV4Translated := ip[8] == 0xFF && ip[9] == 0xFF && ip[10] == 0x00 && ip[11] == 0x00 - if isV4Translated { - for i := 0; i < 8; i++ { - if ip[i] != 0 { - isV4Translated = false - break - } - } - } - if isV4Translated && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } - } + if _, detail, blocked := ssrfCheckIP(ip, originalHost, p.blockedCIDRs, p.blockPrivate); blocked { + return fmt.Errorf("%s", detail) } return nil @@ -251,10 +192,32 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. +// +// SECURITY NOTE (TOCTOU): This check is inherently susceptible to +// time-of-check-to-time-of-use races — the filesystem state may change between +// the validation here and the actual file access by the caller. +// +// Mitigations callers should consider: +// - Use O_NOFOLLOW when opening files after validation. +// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on +// Linux 5.6+) where possible. +// - Open the file immediately after validation and re-verify the resolved +// path via /proc/self/fd/N or fstat before processing. +// - Avoid writing to directories that untrusted users can modify. 
func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - defer p.mu.RUnlock() + fn := p.onBlocked + err := p.checkPathCore(path) + p.mu.RUnlock() + + if fn != nil && err != nil { + fn(err.Error()) + } + + return err +} +func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil } @@ -348,3 +311,48 @@ func resolveExistingPrefix(p string) string { } } } + +// redirectBlockedHosts lists hostnames that HTTP redirects must never follow. +// This covers cloud metadata services that attackers commonly target via +// redirect-based SSRF (the initial request hits an allowed host which 302s +// to the metadata endpoint). +var redirectBlockedHosts = map[string]bool{ + "169.254.169.254": true, + "fd00:ec2::254": true, + "metadata.google.internal": true, + "100.100.100.200": true, +} + +// SSRFSafeRedirect is an http.Client CheckRedirect function that blocks +// redirects to cloud metadata endpoints, private/loopback addresses, and +// non-HTTP(S) schemes. Assign it to http.Client.CheckRedirect. +func SSRFSafeRedirect(req *http.Request, via []*http.Request) error { + host := strings.ToLower(req.URL.Hostname()) + + // Block known metadata endpoints. + if redirectBlockedHosts[host] { + return fmt.Errorf("redirect to blocked metadata host: %s", host) + } + + // Block redirects to private/loopback IPs. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return fmt.Errorf("redirect to private/loopback IP: %s", host) + } + } + + // Block non-HTTP(S) scheme redirects (e.g., file://, gopher://). + switch req.URL.Scheme { + case "http", "https": + // allowed + default: + return fmt.Errorf("redirect to disallowed scheme: %s", req.URL.Scheme) + } + + // Enforce standard redirect limit. 
+ if len(via) >= 10 { + return fmt.Errorf("stopped after %d redirects", len(via)) + } + + return nil +} diff --git a/cli/azd/pkg/azdext/output.go b/cli/azd/pkg/azdext/output.go new file mode 100644 index 00000000000..ada93d09569 --- /dev/null +++ b/cli/azd/pkg/azdext/output.go @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/fatih/color" +) + +// OutputFormat represents the output format for extension commands. +type OutputFormat string + +const ( + // OutputFormatDefault is human-readable text with optional color. + OutputFormatDefault OutputFormat = "default" + // OutputFormatJSON outputs structured JSON for machine consumption. + OutputFormatJSON OutputFormat = "json" +) + +// ParseOutputFormat converts a string to an OutputFormat. +// Returns OutputFormatDefault for unrecognized values and a non-nil error. +func ParseOutputFormat(s string) (OutputFormat, error) { + switch strings.ToLower(s) { + case "default", "": + return OutputFormatDefault, nil + case "json": + return OutputFormatJSON, nil + default: + return OutputFormatDefault, fmt.Errorf("invalid output format %q (valid: default, json)", s) + } +} + +// OutputOptions configures an [Output] instance. +type OutputOptions struct { + // Format controls the output style. Defaults to OutputFormatDefault. + Format OutputFormat + // Writer is the destination for normal output. Defaults to os.Stdout. + Writer io.Writer + // ErrWriter is the destination for error/warning output. Defaults to os.Stderr. + ErrWriter io.Writer +} + +// Output provides formatted, format-aware output for extension commands. +// In default mode it writes human-readable text with ANSI color; in JSON mode +// it writes structured JSON objects to stdout and suppresses decorative output. +// +// Output is safe for use from a single goroutine. 
If concurrent use is needed +// callers should synchronize externally. +type Output struct { + writer io.Writer + errWriter io.Writer + format OutputFormat + + // Color printers — configured once at construction. + successColor *color.Color + warningColor *color.Color + errorColor *color.Color + infoColor *color.Color + headerColor *color.Color + dimColor *color.Color +} + +// NewOutput creates an Output configured by opts. +// If opts.Writer or opts.ErrWriter are nil they default to os.Stdout / os.Stderr. +func NewOutput(opts OutputOptions) *Output { + w := opts.Writer + if w == nil { + w = os.Stdout + } + ew := opts.ErrWriter + if ew == nil { + ew = os.Stderr + } + + return &Output{ + writer: w, + errWriter: ew, + format: opts.Format, + successColor: color.New(color.FgGreen), + warningColor: color.New(color.FgYellow), + errorColor: color.New(color.FgRed), + infoColor: color.New(color.FgCyan), + headerColor: color.New(color.Bold), + dimColor: color.New(color.Faint), + } +} + +// IsJSON returns true when the output format is JSON. +// Callers can use this to skip decorative output that is only relevant in +// human-readable mode. +func (o *Output) IsJSON() bool { + return o.format == OutputFormatJSON +} + +// Success prints a success message prefixed with a green check mark. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Success(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.successColor.Fprintf(o.writer, "(✓) Done: %s\n", msg) +} + +// Warning prints a warning message prefixed with a yellow exclamation mark. +// Warnings are written to ErrWriter in both default and JSON mode so they are +// visible even when stdout is piped through a JSON consumer. +func (o *Output) Warning(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + // In JSON mode emit a structured warning to stderr. 
+ _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "warning", + "message": msg, + }) + return + } + o.warningColor.Fprintf(o.errWriter, "(!) Warning: %s\n", sanitizeOutputText(msg)) +} + +// Error prints an error message prefixed with a red cross. +// Errors are always written to ErrWriter. +func (o *Output) Error(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "error", + "message": msg, + }) + return + } + o.errorColor.Fprintf(o.errWriter, "(✗) Error: %s\n", sanitizeOutputText(msg)) +} + +// Info prints an informational message prefixed with an info symbol. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Info(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.infoColor.Fprintf(o.writer, "(i) %s\n", msg) +} + +// Message prints an undecorated message to stdout. +// In JSON mode the call is a no-op. +func (o *Output) Message(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + fmt.Fprintln(o.writer, msg) +} + +// JSON writes data as a pretty-printed JSON object to stdout. +// It is active in all output modes so callers can unconditionally emit +// structured payloads (in default mode the JSON is still human-readable). +func (o *Output) JSON(data any) error { + enc := json.NewEncoder(o.writer) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { + return fmt.Errorf("output: failed to encode JSON: %w", err) + } + return nil +} + +// Table prints a formatted text table with headers and rows. +// In JSON mode the table is emitted as a JSON array of objects instead. +// +// headers defines the column names. Each row is a slice of cell values +// with the same length as headers. 
Rows with fewer cells are padded with +// empty strings; extra cells are silently ignored. +func (o *Output) Table(headers []string, rows [][]string) { + if len(headers) == 0 { + return + } + + if o.IsJSON() { + o.tableJSON(headers, rows) + return + } + + o.tableText(headers, rows) +} + +// tableJSON emits the table as a JSON array of objects keyed by header name. +func (o *Output) tableJSON(headers []string, rows [][]string) { + out := make([]map[string]string, 0, len(rows)) + for _, row := range rows { + obj := make(map[string]string, len(headers)) + for i, h := range headers { + if i < len(row) { + obj[h] = row[i] + } else { + obj[h] = "" + } + } + out = append(out, obj) + } + _ = o.JSON(out) +} + +// tableText renders an aligned text table with a header separator. +func (o *Output) tableText(headers []string, rows [][]string) { + // Calculate column widths. + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + for _, row := range rows { + for i := range headers { + if i < len(row) && len(row[i]) > widths[i] { + widths[i] = len(row[i]) + } + } + } + + // Print header row. + for i, h := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + o.headerColor.Fprintf(o.writer, "%-*s", widths[i], h) + } + fmt.Fprintln(o.writer) + + // Print separator. + for i, w := range widths { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + fmt.Fprint(o.writer, strings.Repeat("─", w)) + } + fmt.Fprintln(o.writer) + + // Print data rows. + for _, row := range rows { + for i := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + cell := "" + if i < len(row) { + cell = sanitizeOutputText(row[i]) + } + fmt.Fprintf(o.writer, "%-*s", widths[i], cell) + } + fmt.Fprintln(o.writer) + } +} + +// sanitizeOutputText replaces CR, LF, and other ASCII control characters +// (except TAB) with a space so that untrusted values embedded in text-mode +// output cannot forge log lines or inject terminal escape sequences. 
+// JSON-mode output is NOT sanitized here because json.Encoder already +// escapes control characters in string values. +func sanitizeOutputText(s string) string { + return strings.Map(func(r rune) rune { + if r == '\t' { + return r + } + if r < 0x20 || r == 0x7F { + return ' ' + } + return r + }, s) +} diff --git a/cli/azd/pkg/azdext/output_test.go b/cli/azd/pkg/azdext/output_test.go new file mode 100644 index 00000000000..4a31fe7ccae --- /dev/null +++ b/cli/azd/pkg/azdext/output_test.go @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ParseOutputFormat +// --------------------------------------------------------------------------- + +func TestParseOutputFormat(t *testing.T) { + tests := []struct { + name string + input string + expected OutputFormat + expectErr bool + }{ + {name: "default string", input: "default", expected: OutputFormatDefault}, + {name: "empty string", input: "", expected: OutputFormatDefault}, + {name: "json lowercase", input: "json", expected: OutputFormatJSON}, + {name: "JSON uppercase", input: "JSON", expected: OutputFormatJSON}, + {name: "Json mixed case", input: "Json", expected: OutputFormatJSON}, + {name: "DEFAULT uppercase", input: "DEFAULT", expected: OutputFormatDefault}, + {name: "invalid format", input: "xml", expected: OutputFormatDefault, expectErr: true}, + {name: "invalid format yaml", input: "yaml", expected: OutputFormatDefault, expectErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseOutputFormat(tt.input) + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid output format") + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expected, got) + }) + } +} 
+ +// --------------------------------------------------------------------------- +// NewOutput defaults +// --------------------------------------------------------------------------- + +func TestNewOutput_DefaultWriters(t *testing.T) { + out := NewOutput(OutputOptions{}) + require.NotNil(t, out) + // Default format should be "default" (zero-value of OutputFormat). + require.False(t, out.IsJSON()) +} + +func TestNewOutput_JSONMode(t *testing.T) { + out := NewOutput(OutputOptions{Format: OutputFormatJSON}) + require.True(t, out.IsJSON()) +} + +// --------------------------------------------------------------------------- +// Success +// --------------------------------------------------------------------------- + +func TestOutput_Success_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Success("deployed %s", "myapp") + + // Should contain the message text (color codes may wrap it). + require.Contains(t, buf.String(), "Done: deployed myapp") +} + +func TestOutput_Success_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Success("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Warning +// --------------------------------------------------------------------------- + +func TestOutput_Warning_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Warning("deprecated %s", "v1") + + require.Contains(t, errBuf.String(), "Warning: deprecated v1") +} + +func TestOutput_Warning_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Warning("api deprecated") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, 
"warning", parsed["level"]) + require.Equal(t, "api deprecated", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +func TestOutput_Error_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Error("connection failed: %s", "timeout") + + require.Contains(t, errBuf.String(), "Error: connection failed: timeout") +} + +func TestOutput_Error_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Error("disk full") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "error", parsed["level"]) + require.Equal(t, "disk full", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Info +// --------------------------------------------------------------------------- + +func TestOutput_Info_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Info("fetching %d items", 5) + + require.Contains(t, buf.String(), "fetching 5 items") +} + +func TestOutput_Info_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Info("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Message +// --------------------------------------------------------------------------- + +func TestOutput_Message_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Message("plain text %d", 42) + + require.Equal(t, "plain text 42\n", buf.String()) +} + +func TestOutput_Message_JSONFormat_IsNoop(t *testing.T) { + var buf 
bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Message("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// JSON +// --------------------------------------------------------------------------- + +func TestOutput_JSON_Struct(t *testing.T) { + type result struct { + Name string `json:"name"` + Count int `json:"count"` + } + + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(result{Name: "test", Count: 7}) + require.NoError(t, err) + + var decoded result + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "test", decoded.Name) + require.Equal(t, 7, decoded.Count) +} + +func TestOutput_JSON_Map(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]string{"key": "value"}) + require.NoError(t, err) + + var decoded map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "value", decoded["key"]) +} + +func TestOutput_JSON_Unmarshalable(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(make(chan int)) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to encode JSON") +} + +func TestOutput_JSON_PrettyPrinted(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]int{"a": 1}) + require.NoError(t, err) + + // Verify indentation is present (pretty-printed). 
+ require.Contains(t, buf.String(), " ") +} + +// --------------------------------------------------------------------------- +// Table — default format +// --------------------------------------------------------------------------- + +func TestOutput_Table_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"Name", "Status"} + rows := [][]string{ + {"api", "running"}, + {"web", "stopped"}, + } + + out.Table(headers, rows) + + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "Status") + require.Contains(t, text, "api") + require.Contains(t, text, "running") + require.Contains(t, text, "web") + require.Contains(t, text, "stopped") + + // Separator line should be present. + require.Contains(t, text, "─") +} + +func TestOutput_Table_EmptyHeaders(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table(nil, [][]string{{"a"}}) + + require.Empty(t, buf.String()) +} + +func TestOutput_Table_EmptyRows(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table([]string{"Name"}, nil) + + // Header + separator should still be printed. + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "─") +} + +func TestOutput_Table_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + // Row has fewer cells than headers — should pad with empty strings. + out.Table([]string{"A", "B", "C"}, [][]string{{"only-a"}}) + + text := buf.String() + require.Contains(t, text, "only-a") + // No panic from short row. 
+ lines := strings.Split(strings.TrimSpace(text), "\n") + require.Len(t, lines, 3) // header + separator + 1 data row +} + +func TestOutput_Table_ColumnAlignment(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"ID", "LongerName"} + rows := [][]string{ + {"1", "short"}, + {"2", "a-much-longer-value"}, + } + + out.Table(headers, rows) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.GreaterOrEqual(t, len(lines), 3) + + // All separator dashes should align with header width. + sepLine := lines[1] + require.NotEmpty(t, sepLine) +} + +// --------------------------------------------------------------------------- +// Table — JSON format +// --------------------------------------------------------------------------- + +func TestOutput_Table_JSONFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + headers := []string{"Service", "Port"} + rows := [][]string{ + {"api", "8080"}, + {"web", "3000"}, + } + + out.Table(headers, rows) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 2) + require.Equal(t, "api", decoded[0]["Service"]) + require.Equal(t, "8080", decoded[0]["Port"]) + require.Equal(t, "web", decoded[1]["Service"]) + require.Equal(t, "3000", decoded[1]["Port"]) +} + +func TestOutput_Table_JSONFormat_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Table([]string{"A", "B"}, [][]string{{"only-a"}}) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 1) + require.Equal(t, "only-a", decoded[0]["A"]) + require.Equal(t, "", decoded[0]["B"]) +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 6f95645686d..9061d3fbd4e 100644 --- 
a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -14,10 +14,19 @@ import ( "strings" ) +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + const ( // maxPageResponseSize limits the maximum size of a single page response // body to prevent excessive memory consumption from malicious or - // misconfigured servers. + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. maxPageResponseSize int64 = 10 << 20 // 10 MB // maxErrorBodySize limits the size of error response bodies captured @@ -44,6 +53,8 @@ type Pager[T any] struct { done bool opts PagerOptions originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far + truncated bool } // PageResponse is a single page returned by [Pager.NextPage]. @@ -59,6 +70,18 @@ type PageResponse[T any] struct { type PagerOptions struct { // Method overrides the HTTP method used for page requests. Defaults to GET. Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. + MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). 
+ MaxItems int } // HTTPDoer abstracts the HTTP call so that [ResilientClient] or any @@ -108,7 +131,11 @@ func NewPager[T any](client HTTPDoer, firstURL string, opts *PagerOptions) *Page } // NewPagerFromHTTPClient creates a [Pager] backed by a standard [*http.Client]. +// If client is nil, [http.DefaultClient] is used. func NewPagerFromHTTPClient[T any](client *http.Client, firstURL string, opts *PagerOptions) *Pager[T] { + if client == nil { + client = http.DefaultClient + } return NewPager[T](&stdHTTPDoer{client: client}, firstURL, opts) } @@ -117,6 +144,12 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } +// Truncated reports whether the last [Collect] call stopped early due to +// MaxPages or MaxItems limits. +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + // NextPage fetches the next page of results. Returns an error if the request // fails, the response is not 2xx, or the body cannot be decoded. // @@ -145,7 +178,7 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: string(body), + Body: sanitizeErrorBody(string(body)), } } @@ -170,6 +203,9 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { p.nextURL = page.NextLink } + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + return &page, nil } @@ -199,12 +235,23 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { } // Collect is a convenience method that fetches all remaining pages and -// returns all items in a single slice. Use with caution on large result sets. +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. 
// // If NextPage returns both page data and an error (e.g. rejected nextLink), // the page data is included in the returned slice before returning the error. func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + p.truncated = false + + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } for p.More() { page, err := p.NextPage(ctx) @@ -214,11 +261,28 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { if err != nil { return all, err } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + p.truncated = true + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. + if p.pageCount >= maxPages { + p.truncated = true + break + } } return all, nil } +const maxPaginationErrorBodyLen = 1024 + // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int @@ -229,6 +293,40 @@ type PaginationError struct { func (e *PaginationError) Error() string { return fmt.Sprintf( "azdext.Pager: page request returned HTTP %d (url=%s)", - e.StatusCode, e.URL, + e.StatusCode, redactURL(e.URL), ) } + +func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. 
+func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 1f8c3537279..58e80ef9f4c 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -481,3 +482,106 @@ func TestPager_CollectWithSSRFError(t *testing.T) { t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) } } + +func TestPager_TruncatedByMaxPages(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxPages)") + } +} + +func TestPager_TruncatedByMaxItems(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, 
+ }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxItems)") + } +} + +func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 2 { + t.Errorf("len(all) = %d, want 2", len(all)) + } + + if pager.Truncated() { + t.Error("Truncated() = true, want false (natural end)") + } +} diff --git a/cli/azd/pkg/azdext/process.go b/cli/azd/pkg/azdext/process.go new file mode 100644 index 00000000000..5fdf4aa5e3d --- /dev/null +++ b/cli/azd/pkg/azdext/process.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "os" + "runtime" + "strings" +) + +// --------------------------------------------------------------------------- +// P3-5: Cross-platform process detection +// --------------------------------------------------------------------------- + +// ProcessInfo contains information about a running process. +type ProcessInfo struct { + // PID is the process identifier. + PID int + // Name is the process name (executable basename without extension). + Name string + // Executable is the full path to the process executable, if available. 
+ Executable string + // Running is true if the process was found and appears to be alive. + Running bool +} + +// IsProcessRunning checks whether a process with the given PID exists and +// is still running. +// +// Platform behavior: +// - Unix (Linux/macOS): Sends signal 0 to the process. If the process +// exists (even if owned by another user), this returns true. If the +// process does not exist, it returns false. This does NOT verify that +// the process is the expected one (PID reuse is possible). +// - Windows: Opens the process with PROCESS_QUERY_LIMITED_INFORMATION +// access and checks the exit code. If the process handle is valid and +// the exit code is STILL_ACTIVE, returns true. +// +// Note: PID reuse can cause false positives on all platforms. For critical +// use cases, combine PID checks with process name verification using +// [GetProcessInfo]. +// +// Returns false if the PID is invalid (≤ 0). +func IsProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + return isProcessRunningOS(pid) +} + +// GetProcessInfo retrieves information about the process with the given PID. +// +// Platform behavior: +// - Linux: Reads /proc//comm, /proc//exe. +// - macOS: Uses ps(1) to query process info. +// - Windows: Uses QueryFullProcessImageName via Windows API. +// +// Returns a [ProcessInfo] with Running=false if the process does not exist +// or cannot be queried (e.g., insufficient permissions). +func GetProcessInfo(pid int) ProcessInfo { + if pid <= 0 { + return ProcessInfo{PID: pid, Running: false} + } + return getProcessInfoOS(pid) +} + +// CurrentProcessInfo returns [ProcessInfo] for the current process. +func CurrentProcessInfo() ProcessInfo { + pid := os.Getpid() + exe, _ := os.Executable() + + name := "" + if exe != "" { + name = extractBaseName(exe) + } + + return ProcessInfo{ + PID: pid, + Name: name, + Executable: exe, + Running: true, + } +} + +// ParentProcessInfo returns [ProcessInfo] for the parent of the current +// process. 
+// +// Platform behavior: +// - All platforms: Uses os.Getppid() to obtain the parent PID, then +// delegates to [GetProcessInfo]. +// - On orphaned processes (parent PID = 1 on Unix), the returned info +// describes the init/launchd process. +func ParentProcessInfo() ProcessInfo { + return GetProcessInfo(os.Getppid()) +} + +// FindProcessByName searches for running processes with the given name. +// The search is case-insensitive and matches the executable basename +// (without file extension on Windows). +// +// Platform behavior: +// - Linux: Scans /proc/*/comm. +// - macOS: Uses ps(1) to list processes. +// - Windows: Uses CreateToolhelp32Snapshot to enumerate processes. +// +// Returns a slice of matching [ProcessInfo]. If no processes are found, +// returns an empty (non-nil) slice. +// +// This function is best-effort: some processes may be inaccessible due to +// permissions. +func FindProcessByName(name string) []ProcessInfo { + if name == "" { + return []ProcessInfo{} + } + return findProcessByNameOS(name) +} + +// ProcessEnvironment describes the process execution context for diagnostics. +type ProcessEnvironment struct { + // PID is the current process ID. + PID int + // PPID is the parent process ID. + PPID int + // Executable is the current process executable path. + Executable string + // WorkingDir is the current working directory. + WorkingDir string + // OS is the operating system (runtime.GOOS). + OS string + // Arch is the CPU architecture (runtime.GOARCH). + Arch string + // NumCPU is the number of logical CPUs available. + NumCPU int +} + +// GetProcessEnvironment collects process execution context useful for +// diagnostics, logging, and support information. 
+func GetProcessEnvironment() ProcessEnvironment {
+	exe, _ := os.Executable()
+	cwd, _ := os.Getwd()
+
+	return ProcessEnvironment{
+		PID:        os.Getpid(),
+		PPID:       os.Getppid(),
+		Executable: exe,
+		WorkingDir: cwd,
+		OS:         runtime.GOOS,
+		Arch:       runtime.GOARCH,
+		NumCPU:     runtime.NumCPU(),
+	}
+}
+
+// String returns a human-readable summary of the process environment.
+func (pe ProcessEnvironment) String() string {
+	return fmt.Sprintf("pid=%d ppid=%d os=%s arch=%s cpus=%d cwd=%s exe=%s",
+		pe.PID, pe.PPID, pe.OS, pe.Arch, pe.NumCPU, pe.WorkingDir, pe.Executable)
+}
+
+// ---------------------------------------------------------------------------
+// Internal shared helpers
+// ---------------------------------------------------------------------------
+
+// extractBaseName returns the base name of a path without extension.
+func extractBaseName(path string) string {
+	// Handle both Unix and Windows separators.
+	base := path
+	if idx := strings.LastIndexAny(path, `/\`); idx >= 0 {
+		base = path[idx+1:]
+	}
+	// Remove common extensions.
+	base = strings.TrimSuffix(base, ".exe")
+	return base
+}
diff --git a/cli/azd/pkg/azdext/process_darwin.go b/cli/azd/pkg/azdext/process_darwin.go
new file mode 100644
index 00000000000..4c523a5ab1a
--- /dev/null
+++ b/cli/azd/pkg/azdext/process_darwin.go
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//go:build darwin
+
+package azdext
+
+import (
+	"os"
+	"os/exec"
+	"strconv"
+	"strings"
+	"syscall"
+)
+
+// isProcessRunningOS checks if a process is running on macOS using signal 0.
+func isProcessRunningOS(pid int) bool {
+	proc, err := os.FindProcess(pid)
+	if err != nil {
+		return false
+	}
+	err = proc.Signal(syscall.Signal(0))
+	return err == nil || err == syscall.EPERM // EPERM: process exists, owned by another user.
+}
+
+// getProcessInfoOS retrieves process info on macOS using ps(1).
+func getProcessInfoOS(pid int) ProcessInfo {
+	info := ProcessInfo{PID: pid}
+
+	// Use ps to get process name and executable path.
+ cmd := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=") + output, err := cmd.Output() + if err != nil { + return info // Process does not exist or is inaccessible. + } + + info.Name = extractBaseName(strings.TrimSpace(string(output))) + info.Running = true + + // Get full command path. + cmd = exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "args=") + output, err = cmd.Output() + if err == nil { + args := strings.TrimSpace(string(output)) + if fields := strings.Fields(args); len(fields) > 0 { + info.Executable = fields[0] + } + } + + return info +} + +// findProcessByNameOS searches for processes by name on macOS using ps(1). +func findProcessByNameOS(name string) []ProcessInfo { + // ps -ax -o pid=,comm= lists all processes with PID and command name. + cmd := exec.Command("ps", "-ax", "-o", "pid=,comm=") + output, err := cmd.Output() + if err != nil { + return []ProcessInfo{} + } + + nameLower := strings.ToLower(name) + var results []ProcessInfo + + for _, line := range strings.Split(string(output), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Format: " PID COMM" + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + pid, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + // The comm field is the rest of the line after PID. + procPath := strings.Join(fields[1:], " ") + procName := extractBaseName(procPath) + + if strings.EqualFold(procName, nameLower) { + results = append(results, ProcessInfo{ + PID: pid, + Name: procName, + Executable: procPath, + Running: true, + }) + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/process_linux.go b/cli/azd/pkg/azdext/process_linux.go new file mode 100644 index 00000000000..85daedaacf3 --- /dev/null +++ b/cli/azd/pkg/azdext/process_linux.go @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+//go:build !windows && !darwin
+
+package azdext
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+	"syscall"
+)
+
+// isProcessRunningOS checks if a process is running on Linux using signal 0.
+func isProcessRunningOS(pid int) bool {
+	proc, err := os.FindProcess(pid)
+	if err != nil {
+		return false
+	}
+	// Signal 0 does not send a signal but performs error checking.
+	// EPERM means the process exists but is owned by another user.
+	err = proc.Signal(syscall.Signal(0))
+	return err == nil || err == syscall.EPERM
+}
+
+// getProcessInfoOS retrieves process info on Linux via /proc.
+func getProcessInfoOS(pid int) ProcessInfo {
+	info := ProcessInfo{PID: pid}
+
+	// Read process name from /proc//comm.
+	commPath := fmt.Sprintf("/proc/%d/comm", pid)
+	commData, err := os.ReadFile(commPath)
+	if err != nil {
+		return info // Process does not exist or is inaccessible.
+	}
+	info.Name = strings.TrimSpace(string(commData))
+	info.Running = true
+
+	// Read executable symlink from /proc//exe.
+	exePath := fmt.Sprintf("/proc/%d/exe", pid)
+	exe, err := os.Readlink(exePath)
+	if err == nil {
+		info.Executable = exe
+	}
+
+	return info
+}
+
+// findProcessByNameOS searches for processes by name on Linux via /proc.
+func findProcessByNameOS(name string) []ProcessInfo {
+	entries, err := os.ReadDir("/proc")
+	if err != nil {
+		return []ProcessInfo{}
+	}
+
+	nameLower := strings.ToLower(name)
+	var results []ProcessInfo
+
+	for _, entry := range entries {
+		if !entry.IsDir() {
+			continue
+		}
+		pid, err := strconv.Atoi(entry.Name())
+		if err != nil {
+			continue // Not a PID directory.
+		}
+
+		commPath := fmt.Sprintf("/proc/%d/comm", pid)
+		commData, err := os.ReadFile(commPath)
+		if err != nil {
+			continue
+		}
+
+		procName := strings.TrimSpace(string(commData))
+		if strings.EqualFold(procName, nameLower) {
+			info := ProcessInfo{
+				PID:     pid,
+				Name:    procName,
+				Running: true,
+			}
+			// Try to get executable path.
+ exePath := fmt.Sprintf("/proc/%d/exe", pid) + if exe, err := os.Readlink(exePath); err == nil { + info.Executable = exe + } + results = append(results, info) + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/process_test.go b/cli/azd/pkg/azdext/process_test.go new file mode 100644 index 00000000000..c59de779795 --- /dev/null +++ b/cli/azd/pkg/azdext/process_test.go @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "os" + "runtime" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// IsProcessRunning +// --------------------------------------------------------------------------- + +func TestIsProcessRunning_CurrentProcess(t *testing.T) { + pid := os.Getpid() + if !IsProcessRunning(pid) { + t.Errorf("IsProcessRunning(%d) = false for current process, want true", pid) + } +} + +func TestIsProcessRunning_InvalidPID(t *testing.T) { + if IsProcessRunning(0) { + t.Error("IsProcessRunning(0) = true, want false") + } + if IsProcessRunning(-1) { + t.Error("IsProcessRunning(-1) = true, want false") + } +} + +func TestIsProcessRunning_NonexistentPID(t *testing.T) { + // PID 99999999 is extremely unlikely to exist. + if IsProcessRunning(99999999) { + t.Error("IsProcessRunning(99999999) = true, want false") + } +} + +// --------------------------------------------------------------------------- +// GetProcessInfo +// --------------------------------------------------------------------------- + +func TestGetProcessInfo_CurrentProcess(t *testing.T) { + pid := os.Getpid() + info := GetProcessInfo(pid) + if !info.Running { + t.Errorf("GetProcessInfo(%d).Running = false for current process, want true", pid) + } + if info.PID != pid { + t.Errorf("GetProcessInfo(%d).PID = %d, want %d", pid, info.PID, pid) + } + // Name should be non-empty for the current process. 
+ if info.Name == "" { + t.Errorf("GetProcessInfo(%d).Name is empty, want non-empty", pid) + } +} + +func TestGetProcessInfo_InvalidPID(t *testing.T) { + info := GetProcessInfo(-1) + if info.Running { + t.Error("GetProcessInfo(-1).Running = true, want false") + } +} + +func TestGetProcessInfo_NonexistentPID(t *testing.T) { + info := GetProcessInfo(99999999) + if info.Running { + t.Error("GetProcessInfo(99999999).Running = true, want false") + } +} + +// --------------------------------------------------------------------------- +// CurrentProcessInfo +// --------------------------------------------------------------------------- + +func TestCurrentProcessInfo(t *testing.T) { + info := CurrentProcessInfo() + if !info.Running { + t.Error("CurrentProcessInfo().Running = false, want true") + } + if info.PID != os.Getpid() { + t.Errorf("CurrentProcessInfo().PID = %d, want %d", info.PID, os.Getpid()) + } + if info.Executable == "" { + t.Error("CurrentProcessInfo().Executable is empty") + } + if info.Name == "" { + t.Error("CurrentProcessInfo().Name is empty") + } +} + +// --------------------------------------------------------------------------- +// ParentProcessInfo +// --------------------------------------------------------------------------- + +func TestParentProcessInfo(t *testing.T) { + info := ParentProcessInfo() + // Parent process should exist (the test runner). + if info.PID <= 0 { + t.Errorf("ParentProcessInfo().PID = %d, want > 0", info.PID) + } + // In most environments, the parent should be running. + // Skip assertion on Running since it depends on the test environment. 
+} + +// --------------------------------------------------------------------------- +// FindProcessByName +// --------------------------------------------------------------------------- + +func TestFindProcessByName_Empty(t *testing.T) { + results := FindProcessByName("") + if results == nil { + t.Error("FindProcessByName(\"\") returned nil, want empty slice") + } + if len(results) != 0 { + t.Errorf("FindProcessByName(\"\") returned %d results, want 0", len(results)) + } +} + +func TestFindProcessByName_CurrentProcess(t *testing.T) { + // Get current process name. + current := CurrentProcessInfo() + if current.Name == "" { + t.Skip("cannot determine current process name") + } + + results := FindProcessByName(current.Name) + if len(results) == 0 { + t.Errorf("FindProcessByName(%q) returned 0 results, want >= 1", current.Name) + } + + // Verify at least one result matches our PID. + found := false + for _, r := range results { + if r.PID == current.PID { + found = true + break + } + } + if !found { + t.Errorf("FindProcessByName(%q) did not find current process PID %d", current.Name, current.PID) + } +} + +func TestFindProcessByName_Nonexistent(t *testing.T) { + results := FindProcessByName("azdext-nonexistent-process-xyz") + if results == nil { + t.Error("FindProcessByName(nonexistent) returned nil, want empty slice") + } + if len(results) != 0 { + t.Errorf("FindProcessByName(nonexistent) returned %d results, want 0", len(results)) + } +} + +// --------------------------------------------------------------------------- +// ProcessEnvironment +// --------------------------------------------------------------------------- + +func TestGetProcessEnvironment(t *testing.T) { + env := GetProcessEnvironment() + + if env.PID != os.Getpid() { + t.Errorf("ProcessEnvironment.PID = %d, want %d", env.PID, os.Getpid()) + } + if env.PPID <= 0 { + t.Errorf("ProcessEnvironment.PPID = %d, want > 0", env.PPID) + } + if env.OS != runtime.GOOS { + t.Errorf("ProcessEnvironment.OS = %q, 
want %q", env.OS, runtime.GOOS) + } + if env.Arch != runtime.GOARCH { + t.Errorf("ProcessEnvironment.Arch = %q, want %q", env.Arch, runtime.GOARCH) + } + if env.NumCPU <= 0 { + t.Errorf("ProcessEnvironment.NumCPU = %d, want > 0", env.NumCPU) + } + if env.Executable == "" { + t.Error("ProcessEnvironment.Executable is empty") + } +} + +func TestProcessEnvironment_String(t *testing.T) { + env := GetProcessEnvironment() + s := env.String() + + if !strings.Contains(s, "pid=") { + t.Errorf("ProcessEnvironment.String() = %q, missing pid=", s) + } + if !strings.Contains(s, "os=") { + t.Errorf("ProcessEnvironment.String() = %q, missing os=", s) + } + if !strings.Contains(s, "arch=") { + t.Errorf("ProcessEnvironment.String() = %q, missing arch=", s) + } +} + +// --------------------------------------------------------------------------- +// extractBaseName +// --------------------------------------------------------------------------- + +func TestExtractBaseName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/usr/bin/bash", "bash"}, + {`C:\Windows\System32\cmd.exe`, "cmd"}, + {"simple", "simple"}, + {"program.exe", "program"}, + {"/path/to/my-tool", "my-tool"}, + {`C:\tools\mytool.exe`, "mytool"}, + {"", ""}, + } + for _, tc := range tests { + got := extractBaseName(tc.input) + if got != tc.want { + t.Errorf("extractBaseName(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/cli/azd/pkg/azdext/process_windows.go b/cli/azd/pkg/azdext/process_windows.go new file mode 100644 index 00000000000..c3542d36a48 --- /dev/null +++ b/cli/azd/pkg/azdext/process_windows.go @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//go:build windows + +package azdext + +import ( + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// isProcessRunningOS checks if a process is running on Windows using the +// Windows API. 
Opens the process with PROCESS_QUERY_LIMITED_INFORMATION and +// checks whether the exit code is STILL_ACTIVE (259). +func isProcessRunningOS(pid int) bool { + //nolint:gosec // G115: pid is validated in caller + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + uint32(pid), + ) + if err != nil { + return false + } + defer windows.CloseHandle(handle) + + var exitCode uint32 + err = windows.GetExitCodeProcess(handle, &exitCode) + if err != nil { + return false + } + + // STILL_ACTIVE (259) means the process has not exited. + return exitCode == 259 +} + +// getProcessInfoOS retrieves process info on Windows using the Windows API. +func getProcessInfoOS(pid int) ProcessInfo { + info := ProcessInfo{PID: pid} + + //nolint:gosec // G115: pid is validated in caller + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + uint32(pid), + ) + if err != nil { + return info + } + defer windows.CloseHandle(handle) + + // Check if process is still running. + var exitCode uint32 + if err := windows.GetExitCodeProcess(handle, &exitCode); err != nil { + return info + } + if exitCode != 259 { + return info // Process has exited. + } + + info.Running = true + + // Get executable path. + bufSize := uint32(windows.MAX_PATH) + buf := make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err != nil { + // Try with a larger buffer. + bufSize = 32768 + buf = make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err != nil { + return info + } + } + + info.Executable = syscall.UTF16ToString(buf[:bufSize]) + info.Name = extractBaseName(info.Executable) + + return info +} + +// findProcessByNameOS searches for processes by name on Windows using +// CreateToolhelp32Snapshot. 
+func findProcessByNameOS(name string) []ProcessInfo { + snapshot, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + return []ProcessInfo{} + } + defer windows.CloseHandle(snapshot) + + var entry windows.ProcessEntry32 + entry.Size = uint32(unsafe.Sizeof(entry)) + + if err := windows.Process32First(snapshot, &entry); err != nil { + return []ProcessInfo{} + } + + nameLower := strings.ToLower(name) + var results []ProcessInfo + + for { + exeName := syscall.UTF16ToString(entry.ExeFile[:]) + baseName := extractBaseName(exeName) + + if strings.EqualFold(baseName, nameLower) { + info := ProcessInfo{ + PID: int(entry.ProcessID), + Name: baseName, + Running: true, + } + + // Try to get full executable path. + //nolint:gosec // G115: PID comes from OS snapshot + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + entry.ProcessID, + ) + if err == nil { + bufSize := uint32(windows.MAX_PATH) + buf := make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err == nil { + info.Executable = syscall.UTF16ToString(buf[:bufSize]) + } + windows.CloseHandle(handle) + } + + results = append(results, info) + } + + if err := windows.Process32Next(snapshot, &entry); err != nil { + break + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index 2916b885334..c88acb4ac04 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -9,6 +9,8 @@ import ( "fmt" "io" "math" + "math/rand/v2" + "net" "net/http" "strconv" "time" @@ -88,6 +90,10 @@ func (o *ResilientClientOptions) defaults() { // tokenProvider may be nil if the caller handles Authorization headers manually. 
// When non-nil, the client automatically injects a Bearer token using scopes // resolved from the request URL via the [ScopeDetector]. +// +// The client uses a safe transport that validates resolved IP addresses at +// connection time, mitigating DNS rebinding TOCTOU attacks where a hostname +// resolves to a safe IP at check time but to a private IP at connect time. func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientClientOptions) *ResilientClient { if opts == nil { opts = &ResilientClientOptions{} @@ -97,7 +103,7 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli transport := opts.Transport if transport == nil { - transport = http.DefaultTransport + transport = ssrfSafeTransport() } sd := opts.ScopeDetector @@ -107,8 +113,9 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli return &ResilientClient{ httpClient: &http.Client{ - Transport: transport, - Timeout: opts.Timeout, + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, }, tokenProvider: tokenProvider, scopeDetector: sd, @@ -116,6 +123,49 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli } } +// ssrfSafeTransport returns an [http.RoundTripper] that validates every +// resolved IP address at dial time against SSRF block lists. This closes the +// DNS-rebinding TOCTOU gap where a hostname passes the pre-request SSRF check +// (resolves to a public IP) but is re-resolved by the dialer to a private IP. +// +// The dialer uses the default SSRF guard (metadata + private networks blocked). 
+func ssrfSafeTransport() http.RoundTripper { + guard := DefaultSSRFGuard() + dialer := &net.Dialer{Timeout: 30 * time.Second} + + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("azdext: invalid address %q: %w", addr, err) + } + + // Resolve and validate each IP before connecting. + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("azdext: DNS resolution failed for %s (fail-closed): %w", host, err) + } + + for _, ipAddr := range ips { + if reason, detail, blocked := ssrfCheckIP( + ipAddr.IP, host, guard.blockedCIDRs, guard.blockPrivate, + ); blocked { + return nil, fmt.Errorf( + "azdext: SSRF dial blocked: %s: %s (host=%s)", reason, detail, host) + } + } + + // Connect to the first resolved IP (already validated). + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} + // Do executes an HTTP request with retry logic and optional bearer-token injection. // // body may be nil for requests without a body (GET, DELETE). @@ -127,6 +177,17 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } + // Validate body seekability upfront when retries are enabled. + // Fail fast rather than discovering the body is not seekable after the + // first attempt has already consumed it. 
+ if body != nil && rc.opts.MaxRetries > 0 { + if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } + var lastErr error var retryAfterOverride time.Duration @@ -240,7 +301,8 @@ func (rc *ResilientClient) backoff(attempt int) time.Duration { delay = rc.opts.MaxDelay } - return delay + jitter := 0.5 + rand.Float64()*0.5 + return time.Duration(float64(delay) * jitter) } // isRetryable returns true for status codes that indicate a transient failure. diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 4ffd67af103..7717402229a 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -520,19 +520,20 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { }) // io.NopCloser wrapping strings.NewReader is NOT an io.ReadSeeker. + // With upfront validation, the error is caught before any HTTP call. body := io.NopCloser(strings.NewReader("payload")) _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) if err == nil { - t.Fatal("expected error for non-seekable body on retry") + t.Fatal("expected error for non-seekable body with retries enabled") } if !strings.Contains(err.Error(), "io.ReadSeeker") { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made exactly 1 attempt (first gets 503 → retry → fail on body check). - if attempts != 1 { - t.Errorf("attempts = %d, want 1 (fail before second attempt)", attempts) + // Should have made zero attempts — upfront check rejects before any HTTP call. 
+ if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) } } @@ -620,3 +621,93 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) } } + +func TestResilientClient_BackoffJitter(t *testing.T) { + t.Parallel() + + rc := NewResilientClient(nil, &ResilientClientOptions{ + InitialDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + + seen := make(map[time.Duration]bool) + for range 20 { + d := rc.backoff(1) + seen[d] = true + if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + h := http.Header{} + h.Set("retry-after", 
"999999") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 1, + InitialDelay: time.Millisecond, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded (proving cap was applied), got: %v", err) + } + + if attempts != 1 { + t.Errorf("attempts = %d, want 1", attempts) + } +} diff --git a/cli/azd/pkg/azdext/security_validation.go b/cli/azd/pkg/azdext/security_validation.go new file mode 100644 index 00000000000..5d5e5dd1c7e --- /dev/null +++ b/cli/azd/pkg/azdext/security_validation.go @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +// --------------------------------------------------------------------------- +// Validation error type +// --------------------------------------------------------------------------- + +// ValidationError describes a failed input validation with structured context. +type ValidationError struct { + // Field is the logical name of the input being validated (e.g. "service_name"). + Field string + + // Value is the rejected input value. For security-sensitive inputs the value + // may be truncated or redacted by the caller before constructing the error. + Value string + + // Rule is a short machine-readable tag for the violated constraint + // (e.g. "format", "length", "characters"). + Rule string + + // Message is a human-readable explanation suitable for end-user display. 
+ Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("azdext.Validate: %s %q: %s (%s)", e.Field, e.Value, e.Message, e.Rule) +} + +// --------------------------------------------------------------------------- +// Service name validation +// --------------------------------------------------------------------------- + +// serviceNameRe matches DNS-safe service names: +// - starts with alphanumeric +// - contains only alphanumeric, '.', '_', '-' +// - 1–63 characters total (DNS label limit) +var serviceNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,62}$`) + +// ValidateServiceName checks that name is a valid DNS-safe service identifier. +// +// Rules: +// - Must start with an alphanumeric character. +// - May contain alphanumeric characters, '.', '_', and '-'. +// - Must be 1–63 characters (DNS label length limit per RFC 1035). +// +// Returns a [*ValidationError] on failure. +func ValidateServiceName(name string) error { + if name == "" { + return &ValidationError{ + Field: "service_name", + Value: "", + Rule: "required", + Message: "service name must not be empty", + } + } + + if !serviceNameRe.MatchString(name) { + return &ValidationError{ + Field: "service_name", + Value: truncateValue(name, 64), + Rule: "format", + Message: "service name must start with alphanumeric and contain only [a-zA-Z0-9._-], max 63 chars", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Hostname validation +// --------------------------------------------------------------------------- + +// hostnameRe matches RFC 952/1123 hostnames: +// - labels separated by '.' 
+// - each label: starts and ends with alphanumeric, may contain '-', 1–63 chars +// - total length <= 253 characters +var hostnameRe = regexp.MustCompile( + `^[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?` + + `(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`, +) + +// ValidateHostname checks that hostname conforms to RFC 952/1123. +// +// Rules: +// - Each label must start and end with an alphanumeric character. +// - Labels may contain alphanumeric characters and '-'. +// - Each label is 1–63 characters. +// - Total hostname length is ≤ 253 characters. +// +// Returns a [*ValidationError] on failure. +func ValidateHostname(hostname string) error { + if hostname == "" { + return &ValidationError{ + Field: "hostname", + Value: "", + Rule: "required", + Message: "hostname must not be empty", + } + } + + if len(hostname) > 253 { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "length", + Message: "hostname must not exceed 253 characters", + } + } + + if !hostnameRe.MatchString(hostname) { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "format", + Message: "hostname must conform to RFC 952/1123 " + + "(labels: alphanumeric start/end, may contain '-', 1-63 chars each)", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Script name validation +// --------------------------------------------------------------------------- + +// shellMetacharacters contains characters that have special meaning in common +// shells (bash, sh, zsh, cmd, PowerShell). A script name containing any of +// these is rejected to prevent command injection. +const shellMetacharacters = ";|&`$(){}[]<>!#~*?\"\\'%\n\r\x00" + +// ValidateScriptName checks that name does not contain shell metacharacters +// or path traversal sequences that could lead to command injection. 
+// +// Rejected patterns: +// - Shell metacharacters: ; | & ` $ ( ) { } [ ] < > ! # ~ * ? " ' \ % +// - Path traversal: ".." +// - Null bytes and newlines +// - Empty names +// +// Returns a [*ValidationError] on failure. +func ValidateScriptName(name string) error { + if name == "" { + return &ValidationError{ + Field: "script_name", + Value: "", + Rule: "required", + Message: "script name must not be empty", + } + } + + if strings.Contains(name, "..") { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "traversal", + Message: "script name must not contain path traversal sequences (..)", + } + } + + if idx := strings.IndexAny(name, shellMetacharacters); idx >= 0 { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "characters", + Message: fmt.Sprintf("script name contains forbidden shell metacharacter at position %d", idx), + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Container environment detection +// --------------------------------------------------------------------------- + +// containerEnvVars maps environment variables to the container runtime they indicate. +var containerEnvVars = map[string]string{ + "CODESPACES": "codespaces", + "KUBERNETES_SERVICE_HOST": "kubernetes", + "REMOTE_CONTAINERS": "devcontainer", + "REMOTE_CONTAINERS_IPC": "devcontainer", +} + +// IsContainerEnvironment reports whether the current process is running inside +// a container environment. It checks for: +// - GitHub Codespaces (CODESPACES env var) +// - Kubernetes (KUBERNETES_SERVICE_HOST env var) +// - VS Code Dev Containers (REMOTE_CONTAINERS / REMOTE_CONTAINERS_IPC env vars) +// - Docker (/.dockerenv file) +// +// The detection is best-effort and does not guarantee accuracy in all +// environments. It is intended for feature gating and diagnostics, not +// security decisions. 
+func IsContainerEnvironment() bool {
+	// Check well-known environment variables.
+	for envKey := range containerEnvVars {
+		if v := os.Getenv(envKey); v != "" {
+			return true
+		}
+	}
+
+	// Check for Docker's marker file.
+	if _, err := os.Stat("/.dockerenv"); err == nil {
+		return true
+	}
+
+	return false
+}
+
+// ContainerRuntime returns the detected container runtime name, or an empty
+// string if no container environment is detected.
+//
+// Environment variables are consulted in a fixed priority order so the result
+// is deterministic: Go map iteration order is randomized, so ranging over
+// containerEnvVars directly could report a different runtime on each call
+// when more than one indicator is set (e.g. Codespaces running on
+// Kubernetes).
+//
+// Possible return values: "codespaces", "kubernetes", "devcontainer", "docker", "".
+func ContainerRuntime() string {
+	// Keys of containerEnvVars in deterministic priority order. Keep this
+	// list in sync with the containerEnvVars map above.
+	ordered := []string{
+		"CODESPACES",
+		"KUBERNETES_SERVICE_HOST",
+		"REMOTE_CONTAINERS",
+		"REMOTE_CONTAINERS_IPC",
+	}
+	for _, envKey := range ordered {
+		if os.Getenv(envKey) != "" {
+			return containerEnvVars[envKey]
+		}
+	}
+
+	// Docker's marker file, same check as IsContainerEnvironment.
+	if _, err := os.Stat("/.dockerenv"); err == nil {
+		return "docker"
+	}
+
+	return ""
+}
+
+// ---------------------------------------------------------------------------
+// Internal helpers
+// ---------------------------------------------------------------------------
+
+// truncateValue truncates s to at most maxLen characters (Unicode code
+// points) for safe inclusion in error messages. If truncated, an ellipsis is
+// appended. Truncation happens on rune boundaries so a multi-byte UTF-8
+// sequence is never split mid-sequence, which would otherwise embed invalid
+// UTF-8 in the error text.
+func truncateValue(s string, maxLen int) string {
+	if len(s) <= maxLen {
+		// Fast path: byte length bounds rune length, so s is short enough.
+		return s
+	}
+	r := []rune(s)
+	if len(r) <= maxLen {
+		return s
+	}
+	return string(r[:maxLen]) + "..."
+}
diff --git a/cli/azd/pkg/azdext/security_validation_test.go b/cli/azd/pkg/azdext/security_validation_test.go
new file mode 100644
index 00000000000..cd4165e67dd
--- /dev/null
+++ b/cli/azd/pkg/azdext/security_validation_test.go
@@ -0,0 +1,354 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +package azdext + +import ( + "errors" + "os" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// ValidateServiceName +// --------------------------------------------------------------------------- + +func TestValidateServiceName_Valid(t *testing.T) { + valid := []string{ + "web", + "my-service", + "my_service", + "my.service", + "A1", + "a", + "service-v2.0", + "a23456789012345678901234567890123456789012345678901234567890123", // 63 chars + } + for _, name := range valid { + if err := ValidateServiceName(name); err != nil { + t.Errorf("ValidateServiceName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateServiceName_Invalid(t *testing.T) { + tests := []struct { + name string + rule string + }{ + {"", "required"}, + {"-starts-with-dash", "format"}, + {".starts-with-dot", "format"}, + {"_starts-with-underscore", "format"}, + {"has space", "format"}, + {"has;semicolon", "format"}, + {"has|pipe", "format"}, + {"has$dollar", "format"}, + {"has/slash", "format"}, + {"has@at", "format"}, + // 64 chars (too long) + {"a234567890123456789012345678901234567890123456789012345678901234", "format"}, + } + for _, tc := range tests { + err := ValidateServiceName(tc.name) + if err == nil { + t.Errorf("ValidateServiceName(%q) = nil, want error", tc.name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateServiceName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateServiceName(%q).Rule = %q, want %q", tc.name, ve.Rule, tc.rule) + } + if ve.Field != "service_name" { + t.Errorf("ValidateServiceName(%q).Field = %q, want %q", tc.name, ve.Field, "service_name") + } + } +} + +// --------------------------------------------------------------------------- +// ValidateHostname +// --------------------------------------------------------------------------- + +func TestValidateHostname_Valid(t 
*testing.T) { + valid := []string{ + "example.com", + "sub.example.com", + "a.b.c.d.example.com", + "my-host", + "a", + "1", + "192-168-1-1.nip.io", + "xn--nxasmq6b.example.com", // punycode + } + for _, h := range valid { + if err := ValidateHostname(h); err != nil { + t.Errorf("ValidateHostname(%q) = %v, want nil", h, err) + } + } +} + +func TestValidateHostname_Invalid(t *testing.T) { + tests := []struct { + hostname string + rule string + }{ + {"", "required"}, + {"-starts-with-dash.com", "format"}, + {"ends-with-dash-.com", "format"}, + {"has space.com", "format"}, + {"has_underscore.com", "format"}, + {"has..double-dot.com", "format"}, + {".starts-with-dot.com", "format"}, + {"has;semicolon.com", "format"}, + // 254 chars (too long) + {strings.Repeat("a", 254), "length"}, + } + for _, tc := range tests { + err := ValidateHostname(tc.hostname) + if err == nil { + t.Errorf("ValidateHostname(%q) = nil, want error", tc.hostname) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateHostname(%q) returned %T, want *ValidationError", tc.hostname, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateHostname(%q).Rule = %q, want %q", tc.hostname, ve.Rule, tc.rule) + } + } +} + +func TestValidateHostname_LabelLength(t *testing.T) { + // Each label max 63 chars. A 64-char label should fail. + longLabel := strings.Repeat("a", 64) + ".com" + if err := ValidateHostname(longLabel); err == nil { + t.Errorf("ValidateHostname with 64-char label = nil, want error") + } + + // 63-char label should succeed. 
+ okLabel := strings.Repeat("a", 63) + ".com" + if err := ValidateHostname(okLabel); err != nil { + t.Errorf("ValidateHostname with 63-char label = %v, want nil", err) + } +} + +// --------------------------------------------------------------------------- +// ValidateScriptName +// --------------------------------------------------------------------------- + +func TestValidateScriptName_Valid(t *testing.T) { + valid := []string{ + "script.sh", + "my-script.py", + "build_project.ps1", + "run", + "deploy-v2.sh", + "test.cmd", + "start server", // spaces are OK in script names (not metacharacters) + } + for _, name := range valid { + if err := ValidateScriptName(name); err != nil { + t.Errorf("ValidateScriptName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateScriptName_ShellMetacharacters(t *testing.T) { + dangerous := []struct { + name string + desc string + }{ + {"script;rm -rf /", "semicolon command chaining"}, + {"script|cat /etc/passwd", "pipe"}, + {"script&background", "ampersand"}, + {"script`whoami`", "backtick command substitution"}, + {"script$(id)", "dollar-paren command substitution"}, + {"script > /dev/null", "output redirect"}, + {"script < /etc/passwd", "input redirect"}, + {"script\nrm -rf /", "newline injection"}, + {"script\x00null", "null byte"}, + {"script'quoted'", "single quote"}, + {"script\"quoted\"", "double quote"}, + {"script\\escaped", "backslash"}, + {"script!history", "exclamation/history expansion"}, + {"script#comment", "hash/comment"}, + {"script~home", "tilde expansion"}, + {"script*glob", "glob star"}, + {"script?glob", "glob question"}, + {"script%env", "percent"}, + {"script(sub)", "open paren"}, + {"script{brace}", "open brace"}, + {"script[bracket]", "open bracket"}, + } + for _, tc := range dangerous { + err := ValidateScriptName(tc.name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error (%s)", tc.name, tc.desc) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); 
!ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != "characters" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q (%s)", tc.name, ve.Rule, "characters", tc.desc) + } + } +} + +func TestValidateScriptName_PathTraversal(t *testing.T) { + traversal := []string{ + "../etc/passwd", + "../../secret.sh", + "dir/../../../root.sh", + "..\\windows\\system32", + } + for _, name := range traversal { + err := ValidateScriptName(name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error", name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", name, err) + continue + } + if ve.Rule != "traversal" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q", name, ve.Rule, "traversal") + } + } +} + +func TestValidateScriptName_Empty(t *testing.T) { + err := ValidateScriptName("") + if err == nil { + t.Error("ValidateScriptName(\"\") = nil, want error") + return + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(\"\") returned %T, want *ValidationError", err) + return + } + if ve.Rule != "required" { + t.Errorf("ValidateScriptName(\"\").Rule = %q, want %q", ve.Rule, "required") + } +} + +// --------------------------------------------------------------------------- +// IsContainerEnvironment / ContainerRuntime +// --------------------------------------------------------------------------- + +func TestIsContainerEnvironment_EnvVars(t *testing.T) { + tests := []struct { + envKey string + runtime string + }{ + {"CODESPACES", "codespaces"}, + {"KUBERNETES_SERVICE_HOST", "kubernetes"}, + {"REMOTE_CONTAINERS", "devcontainer"}, + {"REMOTE_CONTAINERS_IPC", "devcontainer"}, + } + for _, tc := range tests { + t.Run(tc.envKey, func(t *testing.T) { + // Ensure the env var is clean before/after. 
+			// Clear every container indicator first so detection cannot be
+			// influenced by the real host: when tests run inside Codespaces,
+			// Kubernetes, or a devcontainer, a stray variable would make
+			// ContainerRuntime return the host's runtime instead of
+			// tc.runtime. t.Setenv registers automatic restoration at subtest
+			// end, so no manual capture/restore is needed.
+			for key := range containerEnvVars {
+				t.Setenv(key, "")
+				os.Unsetenv(key)
+			}
+			t.Setenv(tc.envKey, "true")
+
+			if !IsContainerEnvironment() {
+				t.Errorf("IsContainerEnvironment() = false with %s set", tc.envKey)
+			}
+
+			rt := ContainerRuntime()
+			if rt != tc.runtime {
+				t.Errorf("ContainerRuntime() = %q with %s set, want %q", rt, tc.envKey, tc.runtime)
+			}
+		})
+	}
+}
+
+func TestIsContainerEnvironment_NoContainerEnv(t *testing.T) {
+	// Clear all container-related env vars.
+	for envKey := range containerEnvVars {
+		if v := os.Getenv(envKey); v != "" {
+			t.Setenv(envKey, "")
+			os.Unsetenv(envKey)
+		}
+	}
+
+	// In CI or local dev without Docker marker, this should return false.
+	// We can't guarantee /.dockerenv doesn't exist, but in typical test
+	// environments it won't.
+	runtime := ContainerRuntime()
+	// Only assert no env-var-based detection. Docker file detection is
+	// environment-dependent and not worth mocking here.
+ _ = runtime +} + +// --------------------------------------------------------------------------- +// ValidationError +// --------------------------------------------------------------------------- + +func TestValidationError_ErrorMessage(t *testing.T) { + err := &ValidationError{ + Field: "test_field", + Value: "bad-value", + Rule: "format", + Message: "value is invalid", + } + + msg := err.Error() + if !strings.Contains(msg, "test_field") { + t.Errorf("error message should contain field name, got: %s", msg) + } + if !strings.Contains(msg, "bad-value") { + t.Errorf("error message should contain value, got: %s", msg) + } + if !strings.Contains(msg, "format") { + t.Errorf("error message should contain rule, got: %s", msg) + } +} + +func TestTruncateValue(t *testing.T) { + tests := []struct { + input string + maxLen int + want string + }{ + {"short", 10, "short"}, + {"exactly10!", 10, "exactly10!"}, + {"this is way too long", 10, "this is wa..."}, + {"", 5, ""}, + } + for _, tc := range tests { + got := truncateValue(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("truncateValue(%q, %d) = %q, want %q", tc.input, tc.maxLen, got, tc.want) + } + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// isValidationError is a type assertion helper for testing. +func isValidationError(err error, target **ValidationError) bool { + return errors.As(err, target) +} diff --git a/cli/azd/pkg/azdext/shell.go b/cli/azd/pkg/azdext/shell.go new file mode 100644 index 00000000000..18ddce69bed --- /dev/null +++ b/cli/azd/pkg/azdext/shell.go @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" +) + +// --------------------------------------------------------------------------- +// P3-1: Shell detection and execution +// --------------------------------------------------------------------------- + +// ShellType represents a detected shell environment. +type ShellType string + +const ( + // ShellTypeBash is the Bourne Again Shell. + ShellTypeBash ShellType = "bash" + // ShellTypeSh is the POSIX shell. + ShellTypeSh ShellType = "sh" + // ShellTypeZsh is the Z Shell. + ShellTypeZsh ShellType = "zsh" + // ShellTypeFish is the Fish shell. + ShellTypeFish ShellType = "fish" + // ShellTypePowerShell is PowerShell (pwsh/powershell.exe). + ShellTypePowerShell ShellType = "powershell" + // ShellTypeCmd is Windows cmd.exe. + ShellTypeCmd ShellType = "cmd" + // ShellTypeUnknown indicates the shell could not be determined. + ShellTypeUnknown ShellType = "" +) + +// String returns the string representation of the shell type. +func (s ShellType) String() string { + if s == ShellTypeUnknown { + return "unknown" + } + return string(s) +} + +// ShellInfo contains information about the detected shell. +type ShellInfo struct { + // Type is the detected shell type. + Type ShellType + // Path is the filesystem path to the shell executable, if known. + Path string + // Source describes how the shell was detected. + Source string +} + +// DetectShell identifies the current shell environment. +// +// Detection strategy (in order): +// 1. SHELL environment variable (Unix) — most reliable on macOS/Linux. +// 2. ComSpec environment variable (Windows) — standard Windows shell path. +// 3. PSModulePath environment variable — indicates PowerShell on any platform. +// 4. Platform default fallback (sh on Unix, cmd on Windows). +// +// Platform behavior: +// - Windows: Detects cmd.exe (default), PowerShell, or WSL shells. +// - macOS/Linux: Detects from $SHELL (bash, zsh, fish, sh). 
+// - If $SHELL is unset, falls back to platform default.
+//
+// DetectShell never returns an error. If detection fails, Type is [ShellTypeUnknown].
+func DetectShell() ShellInfo {
+	// Strategy 1: $SHELL (Unix convention, also set in some Windows terminals).
+	// shellTypeFromPath presumably maps an executable path to a ShellType;
+	// it is defined elsewhere in this file.
+	if shellEnv := os.Getenv("SHELL"); shellEnv != "" {
+		st := shellTypeFromPath(shellEnv)
+		if st != ShellTypeUnknown {
+			return ShellInfo{Type: st, Path: shellEnv, Source: "SHELL"}
+		}
+	}
+
+	// Strategy 2: PSModulePath indicates PowerShell is the active shell.
+	//
+	// NOTE(review): on Windows, PSModulePath is commonly set machine-wide for
+	// every process, not only inside PowerShell sessions. Because this
+	// strategy runs before the ComSpec check below, a plain cmd.exe session
+	// on Windows may be misdetected as PowerShell. Confirm the intended
+	// precedence on Windows before relying on this for behavior decisions.
+	if psPath := os.Getenv("PSModulePath"); psPath != "" {
+		// Try to find pwsh/powershell on PATH.
+		if p, err := exec.LookPath("pwsh"); err == nil {
+			return ShellInfo{Type: ShellTypePowerShell, Path: p, Source: "PSModulePath"}
+		}
+		if p, err := exec.LookPath("powershell"); err == nil {
+			return ShellInfo{Type: ShellTypePowerShell, Path: p, Source: "PSModulePath"}
+		}
+		// PowerShell indicated but no executable found on PATH; report the
+		// type with an empty Path rather than failing.
+		return ShellInfo{Type: ShellTypePowerShell, Path: "", Source: "PSModulePath"}
+	}
+
+	// Strategy 3: ComSpec (Windows).
+	if comspec := os.Getenv("ComSpec"); comspec != "" {
+		st := shellTypeFromPath(comspec)
+		if st != ShellTypeUnknown {
+			return ShellInfo{Type: st, Path: comspec, Source: "ComSpec"}
+		}
+		// ComSpec is set but not a recognized shell; assume cmd.
+		return ShellInfo{Type: ShellTypeCmd, Path: comspec, Source: "ComSpec"}
+	}
+
+	// Strategy 4: Platform default. Note the Unix default hard-codes
+	// /bin/sh without verifying the file exists (best-effort fallback).
+	if runtime.GOOS == "windows" {
+		return ShellInfo{Type: ShellTypeCmd, Path: "", Source: "platform-default"}
+	}
+	return ShellInfo{Type: ShellTypeSh, Path: "/bin/sh", Source: "platform-default"}
+}
+
+// ShellCommand creates an [exec.Cmd] that executes script through the
+// appropriate shell for the current platform.
+//
+// Platform behavior:
+//   - Windows cmd: cmd.exe /C