diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go new file mode 100644 index 00000000000..60917274fb3 --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strings" +) + +// ConfigHelper provides typed, ergonomic access to azd configuration through +// the gRPC UserConfig and Environment services. It eliminates the boilerplate +// of raw gRPC calls and JSON marshaling that extension authors otherwise need. +// +// Configuration sources (in merge priority, lowest to highest): +// 1. User config (global azd config) — via UserConfigService +// 2. Environment config (per-env) — via EnvironmentService +// +// Usage: +// +// ch := azdext.NewConfigHelper(client) +// port, err := ch.GetUserString(ctx, "extensions.myext.port") +// var cfg MyConfig +// err = ch.GetUserJSON(ctx, "extensions.myext", &cfg) +type ConfigHelper struct { + client *AzdClient +} + +// NewConfigHelper creates a [ConfigHelper] for the given AZD client. +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) { + if client == nil { + return nil, errors.New("azdext.NewConfigHelper: client must not be nil") + } + + return &ConfigHelper{client: client}, nil +} + +// --- User Config (global) --- + +// GetUserString retrieves a string value from the global user config at the +// given dot-separated path. Returns ("", false, nil) when the path does not +// exist, and ("", false, err) on gRPC errors. +func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. +func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. +func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. +func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). +func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. +func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. + Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +var configSegmentRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`) + +func validatePath(path string) error { + if path == "" { + return errors.New("azdext.ConfigHelper: config path must not be empty") + } + if strings.HasPrefix(path, ".") || strings.HasSuffix(path, ".") || strings.Contains(path, "..") { + return errors.New( + "azdext.ConfigHelper: config path must not have empty segments " + + "(no leading/trailing dots or consecutive dots)", + ) + } + for _, seg := range strings.Split(path, ".") { + if !configSegmentRe.MatchString(seg) { + return fmt.Errorf( + "azdext.ConfigHelper: config path segment %q must start with alphanumeric "+ + "and contain only [a-zA-Z0-9_-], max 63 chars", + truncateConfigValue(seg, 64), + ) + } + } + return nil +} + +func truncateConfigValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go new file mode 100644 index 00000000000..c68464f16bc --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] != false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator := func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + {ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..9818f8d7960 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Secret references use the akvs:// URI scheme: +// +// akvs://// +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions +} + +// secretClientFactory abstracts secret client creation for testability. +type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. +type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider] from P1-1. +// If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client. +func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for an akvs:// reference. +// +// The reference must match the format: akvs://// +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. +func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) { + if ctx == nil { + return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil") + } + + parsed, err := ParseSecretReference(ref) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonInvalidReference, + Err: err, + } + } + + vaultURL := fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix) + + client, err := r.clientFactory(vaultURL, r.credential) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonClientCreation, + Err: fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err), + } + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil) + if err != nil { + reason := ResolveReasonAccessDenied + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + switch respErr.StatusCode { + case http.StatusNotFound: + reason = ResolveReasonNotFound + case http.StatusForbidden, http.StatusUnauthorized: + reason = ResolveReasonAccessDenied + default: + reason = ResolveReasonServiceError + } + } + + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: reason, + Err: fmt.Errorf( + "failed to retrieve secret %q from vault %q: %w", + parsed.SecretName, + parsed.VaultName, + err, + ), + } + } + + if resp.Value == nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonNotFound, + Err: fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName), + } + } + + return *resp.Value, nil +} + +// ResolveMap resolves a map of key → akvs:// references, returning a map of +// key → resolved secret values. Processing stops at the first error. +// +// Non-akvs:// values are passed through unchanged, so callers can safely +// resolve a mixed map of plain values and secret references. +func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) { + if ctx == nil { + return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil") + } + + result := make(map[string]string, len(refs)) + + for key, value := range refs { + if !IsSecretReference(value) { + result[key] = value + continue + } + + resolved, err := r.Resolve(ctx, value) + if err != nil { + return nil, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: key %q: %w", key, err) + } + + result[key] = resolved + } + + return result, nil +} + +// SecretReference represents a parsed akvs:// URI. +type SecretReference struct { + // SubscriptionID is the Azure subscription containing the Key Vault. + SubscriptionID string + + // VaultName is the Key Vault name (not the full URL). + VaultName string + + // SecretName is the name of the secret within the vault. + SecretName string +} + +// IsSecretReference reports whether s uses the akvs:// scheme. +func IsSecretReference(s string) bool { + return keyvault.IsAzureKeyVaultSecret(s) +} + +// vaultNameRe validates Azure Key Vault names per Azure naming rules: +// - 3–24 characters +// - starts with a letter +// - contains only alphanumeric and hyphens +// - does not end with a hyphen +var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`) + +// ParseSecretReference parses an akvs:// URI into its components. +// +// Expected format: akvs://// +// +// The vault name is validated against Azure Key Vault naming rules (3–24 +// characters, starts with letter, alphanumeric and hyphens only, does not +// end with a hyphen). +func ParseSecretReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err + } + + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, end with alphanumeric, and contain only alphanumeric characters and hyphens", + ref, parsed.VaultName, + ) + } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) + } + + return &SecretReference{ + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError]. +type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the akvs:// URI is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. +func (r ResolveReason) String() string { + switch r { + case ResolveReasonInvalidReference: + return "invalid_reference" + case ResolveReasonClientCreation: + return "client_creation" + case ResolveReasonNotFound: + return "not_found" + case ResolveReasonAccessDenied: + return "access_denied" + case ResolveReasonServiceError: + return "service_error" + default: + return "unknown" + } +} + +// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails. +type KeyVaultResolveError struct { + // Reference is the original akvs:// URI that was being resolved. + Reference string + + // Reason classifies the failure. + Reason ResolveReason + + // Err is the underlying error. + Err error +} + +func (e *KeyVaultResolveError) Error() string { + return fmt.Sprintf( + "azdext.KeyVaultResolver: %s (ref=%s): %v", + e.Reason, e.Reference, e.Err, + ) +} + +func (e *KeyVaultResolveError) Unwrap() error { + return e.Err +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go new file mode 100644 index 00000000000..628f301d5d3 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -0,0 +1,575 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// stubSecretGetter is a test double for the Key Vault data-plane client. +type stubSecretGetter struct { + resp azsecrets.GetSecretResponse + err error +} + +func (s *stubSecretGetter) GetSecret( + _ context.Context, _ string, _ string, _ *azsecrets.GetSecretOptions, +) (azsecrets.GetSecretResponse, error) { + return s.resp, s.err +} + +// stubSecretFactory returns a factory that always returns the given stubSecretGetter. +func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(context.Background(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(context.Background(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_SecretNotFound(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/missing-secret") + if err == nil { + t.Fatal("expected error for missing secret") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_AccessDenied(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusForbidden}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for forbidden access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_Unauthorized(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusUnauthorized}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for unauthorized access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_ServiceError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for server error") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to access_denied + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(context.Background(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(context.Background(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestResolveMap_ErrorStopsProcessing(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "secret": "akvs://sub/vault/missing", + } + + _, err := resolver.ResolveMap(context.Background(), input) + if err == nil { + t.Fatal("expected error when resolution fails") + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + {ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + {ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/logger.go b/cli/azd/pkg/azdext/logger.go new file mode 100644 index 00000000000..d0d156ca565 --- /dev/null +++ b/cli/azd/pkg/azdext/logger.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "io" + "log/slog" + "os" + "strings" +) + +// LoggerOptions configures [SetupLogging] and [NewLogger]. +type LoggerOptions struct { + // Debug enables debug-level logging. When false, messages below Info are + // suppressed. If not set explicitly, [NewLogger] checks the AZD_DEBUG + // environment variable. + Debug bool + // Structured selects JSON output when true, human-readable text when false. + Structured bool + // Writer overrides the output destination. Defaults to os.Stderr. + Writer io.Writer +} + +// SetupLogging configures the process-wide default [slog.Logger]. +// It is typically called once at startup (for example from +// [NewExtensionRootCommand]'s PersistentPreRunE callback). +// +// Calling SetupLogging is optional — [NewLogger] works without it and creates +// loggers that inherit from [slog.Default]. SetupLogging is provided for +// extensions that want explicit control over the global log level and format. +func SetupLogging(opts LoggerOptions) { + handler := newHandler(opts) + slog.SetDefault(slog.New(handler)) +} + +// Logger provides component-scoped structured logging built on [log/slog]. +// +// Each Logger carries a "component" attribute so log lines can be filtered or +// routed by subsystem. Additional context can be attached via [Logger.With], +// [Logger.WithComponent], or [Logger.WithOperation]. +// +// Logger writes to stderr by default and never writes to stdout, so it does +// not interfere with command output or JSON-mode piping. +type Logger struct { + slogger *slog.Logger + component string +} + +// NewLogger creates a Logger scoped to the given component name. +// +// If the AZD_DEBUG environment variable is set to a truthy value ("1", "true", +// "yes") and opts.Debug is false, debug logging is enabled automatically. This +// lets extension authors respect the framework's debug flag without extra +// plumbing. +// +// When opts is omitted (zero value), the logger uses Info level with text +// format on stderr. +func NewLogger(component string, opts ...LoggerOptions) *Logger { + var o LoggerOptions + if len(opts) > 0 { + o = opts[0] + } + + // Auto-detect debug from environment when not explicitly set. + if !o.Debug { + o.Debug = isDebugEnv() + } + + handler := newHandler(o) + base := slog.New(handler).With("component", component) + + return &Logger{ + slogger: base, + component: component, + } +} + +// Component returns the component name this logger was created with. +func (l *Logger) Component() string { + return l.component +} + +// Debug logs a message at debug level with optional key-value pairs. +func (l *Logger) Debug(msg string, args ...any) { + l.slogger.Debug(msg, args...) +} + +// Info logs a message at info level with optional key-value pairs. +func (l *Logger) Info(msg string, args ...any) { + l.slogger.Info(msg, args...) +} + +// Warn logs a message at warn level with optional key-value pairs. +func (l *Logger) Warn(msg string, args ...any) { + l.slogger.Warn(msg, args...) +} + +// Error logs a message at error level with optional key-value pairs. +func (l *Logger) Error(msg string, args ...any) { + l.slogger.Error(msg, args...) +} + +// With returns a new Logger that includes the given key-value pairs in every +// subsequent log entry. Keys must be strings; values can be any type +// supported by [slog]. +// +// Example: +// +// l := logger.With("request_id", reqID) +// l.Info("processing") // includes component + request_id +func (l *Logger) With(args ...any) *Logger { + return &Logger{ + slogger: l.slogger.With(args...), + component: l.component, + } +} + +// WithComponent returns a new Logger with a different component name. The +// original component is preserved as "parent_component". +func (l *Logger) WithComponent(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("parent_component", l.component, "component", name), + component: name, + } +} + +// WithOperation returns a new Logger with an "operation" attribute. +func (l *Logger) WithOperation(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("operation", name), + component: l.component, + } +} + +// Slogger returns the underlying [*slog.Logger] for advanced use cases such +// as passing to libraries that accept a standard slog logger. +func (l *Logger) Slogger() *slog.Logger { + return l.slogger +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// newHandler creates an slog.Handler from LoggerOptions. +func newHandler(opts LoggerOptions) slog.Handler { + w := opts.Writer + if w == nil { + w = os.Stderr + } + + level := slog.LevelInfo + if opts.Debug { + level = slog.LevelDebug + } + + handlerOpts := &slog.HandlerOptions{Level: level} + + if opts.Structured { + return slog.NewJSONHandler(w, handlerOpts) + } + return slog.NewTextHandler(w, handlerOpts) +} + +// isDebugEnv checks the AZD_DEBUG environment variable. +// +// Security note: AZD_DEBUG enables verbose logging that may include +// request details, configuration paths, and internal state. It should +// NOT be enabled in production deployments. The variable is intended +// for local development and CI debugging only. +func isDebugEnv() bool { + v := strings.ToLower(os.Getenv("AZD_DEBUG")) + return v == "1" || v == "true" || v == "yes" +} diff --git a/cli/azd/pkg/azdext/logger_test.go b/cli/azd/pkg/azdext/logger_test.go new file mode 100644 index 00000000000..9e21b763ada --- /dev/null +++ b/cli/azd/pkg/azdext/logger_test.go @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// NewLogger — basic construction +// --------------------------------------------------------------------------- + +func TestNewLogger_DefaultOptions(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("test-component", LoggerOptions{Writer: &buf}) + + require.NotNil(t, logger) + require.Equal(t, "test-component", logger.Component()) +} + +func TestNewLogger_ZeroOptions(t *testing.T) { + // Zero-value opts should not panic (writes to stderr). + logger := NewLogger("safe") + require.NotNil(t, logger) + require.Equal(t, "safe", logger.Component()) +} + +func TestNewLogger_NoOpts(t *testing.T) { + // Calling without variadic opts should not panic. + logger := NewLogger("minimal") + require.NotNil(t, logger) +} + +// --------------------------------------------------------------------------- +// Log levels — Info +// --------------------------------------------------------------------------- + +func TestLogger_Info(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("mycomp", LoggerOptions{Writer: &buf}) + + logger.Info("hello world", "key", "val") + + output := buf.String() + require.Contains(t, output, "hello world") + require.Contains(t, output, "key=val") + require.Contains(t, output, "component=mycomp") +} + +// --------------------------------------------------------------------------- +// Log levels — Debug +// --------------------------------------------------------------------------- + +func TestLogger_Debug_Enabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: true, Writer: &buf}) + + logger.Debug("debug message", "detail", "x") + + require.Contains(t, buf.String(), "debug message") + require.Contains(t, buf.String(), "detail=x") +} + +func TestLogger_Debug_Disabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: false, Writer: &buf}) + + logger.Debug("should not appear") + + require.Empty(t, buf.String()) +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar(t *testing.T) { + t.Setenv("AZD_DEBUG", "true") + + var buf bytes.Buffer + logger := NewLogger("env-debug", LoggerOptions{Writer: &buf}) + + logger.Debug("from env var") + + require.Contains(t, buf.String(), "from env var") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_One(t *testing.T) { + t.Setenv("AZD_DEBUG", "1") + + var buf bytes.Buffer + logger := NewLogger("env-one", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via 1") + + require.Contains(t, buf.String(), "debug via 1") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_Yes(t *testing.T) { + t.Setenv("AZD_DEBUG", "yes") + + var buf bytes.Buffer + logger := NewLogger("env-yes", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via yes") + + require.Contains(t, buf.String(), "debug via yes") +} + +func TestLogger_Debug_AZD_DEBUG_Unset(t *testing.T) { + t.Setenv("AZD_DEBUG", "") + + var buf bytes.Buffer + logger := NewLogger("env-empty", LoggerOptions{Writer: &buf}) + + logger.Debug("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Log levels — Warn / Error +// --------------------------------------------------------------------------- + +func TestLogger_Warn(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn-test", LoggerOptions{Writer: &buf}) + + logger.Warn("something concerning", "retries", 3) + + require.Contains(t, buf.String(), "something concerning") + require.Contains(t, buf.String(), "retries=3") +} + +func TestLogger_Error(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("err-test", LoggerOptions{Writer: &buf}) + + logger.Error("bad thing happened", "code", 500) + + require.Contains(t, buf.String(), "bad thing happened") + require.Contains(t, buf.String(), "code=500") +} + +// --------------------------------------------------------------------------- +// Structured (JSON) output +// --------------------------------------------------------------------------- + +func TestLogger_StructuredJSON(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("json-comp", LoggerOptions{Structured: true, Writer: &buf}) + + logger.Info("structured entry", "env", "prod") + + // Each line should be valid JSON. + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "structured entry", parsed["msg"]) + require.Equal(t, "prod", parsed["env"]) + require.Equal(t, "json-comp", parsed["component"]) +} + +// --------------------------------------------------------------------------- +// Context chaining — With +// --------------------------------------------------------------------------- + +func TestLogger_With(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("base", LoggerOptions{Writer: &buf}) + + child := logger.With("request_id", "abc-123") + child.Info("processing") + + output := buf.String() + require.Contains(t, output, "request_id=abc-123") + require.Contains(t, output, "component=base") + require.Contains(t, output, "processing") + + // Child should have the same component. + require.Equal(t, "base", child.Component()) +} + +func TestLogger_With_ChainMultiple(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("chain", LoggerOptions{Writer: &buf}) + + child := logger.With("a", "1").With("b", "2") + child.Info("chained") + + output := buf.String() + require.Contains(t, output, "a=1") + require.Contains(t, output, "b=2") +} + +// --------------------------------------------------------------------------- +// Context chaining — WithComponent +// --------------------------------------------------------------------------- + +func TestLogger_WithComponent(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("parent", LoggerOptions{Structured: true, Writer: &buf}) + + child := logger.WithComponent("child-subsystem") + child.Info("from child") + + require.Equal(t, "child-subsystem", child.Component()) + + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "child-subsystem", parsed["component"]) + require.Equal(t, "parent", parsed["parent_component"]) +} + +// --------------------------------------------------------------------------- +// Context chaining — WithOperation +// --------------------------------------------------------------------------- + +func TestLogger_WithOperation(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("ops", LoggerOptions{Writer: &buf}) + + child := logger.WithOperation("deploy") + child.Info("starting deploy") + + output := buf.String() + require.Contains(t, output, "operation=deploy") + require.Equal(t, "ops", child.Component()) +} + +// --------------------------------------------------------------------------- +// Slogger accessor +// --------------------------------------------------------------------------- + +func TestLogger_Slogger(t *testing.T) { + logger := NewLogger("access", LoggerOptions{Writer: &bytes.Buffer{}}) + require.NotNil(t, logger.Slogger()) +} + +// --------------------------------------------------------------------------- +// SetupLogging — global logger configuration +// --------------------------------------------------------------------------- + +func TestSetupLogging_DoesNotPanic(t *testing.T) { + // SetupLogging modifies slog.Default which is global state. + // We only verify it does not panic here. + var buf bytes.Buffer + SetupLogging(LoggerOptions{Debug: true, Structured: true, Writer: &buf}) + + // Restore a sensible default after the test. + SetupLogging(LoggerOptions{Writer: &bytes.Buffer{}}) +} + +// --------------------------------------------------------------------------- +// isDebugEnv internal helper +// --------------------------------------------------------------------------- + +func TestIsDebugEnv_Truthy(t *testing.T) { + truthy := []string{"1", "true", "TRUE", "True", "yes", "YES", "Yes"} + for _, v := range truthy { + t.Run(v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.True(t, isDebugEnv()) + }) + } +} + +func TestIsDebugEnv_Falsy(t *testing.T) { + falsy := []string{"", "0", "false", "no", "maybe"} + for _, v := range falsy { + t.Run("value="+v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.False(t, isDebugEnv()) + }) + } +} + +// --------------------------------------------------------------------------- +// Text format verification +// --------------------------------------------------------------------------- + +func TestLogger_TextFormat_ContainsLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("lvl", LoggerOptions{Writer: &buf}) + + logger.Info("test level") + + output := buf.String() + require.True(t, + strings.Contains(output, "INFO") || strings.Contains(output, "level=INFO"), + "expected level indicator in text output: %s", output) +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..a7bfb67797f 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" "path/filepath" @@ -23,6 +24,8 @@ type MCPSecurityPolicy struct { allowedBasePaths []string blockedCIDRs []*net.IPNet blockedHosts map[string]bool + // onBlocked is invoked whenever a URL or path is blocked, for audit logging. + onBlocked func(violation string) // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) } @@ -42,12 +45,7 @@ func (p *MCPSecurityPolicy) BlockMetadataEndpoints() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockMetadata = true - for _, host := range []string{ - "169.254.169.254", - "fd00:ec2::254", - "metadata.google.internal", - "100.100.100.200", - } { + for _, host := range ssrfMetadataHosts { p.blockedHosts[strings.ToLower(host)] = true } return p @@ -60,23 +58,7 @@ func (p *MCPSecurityPolicy) BlockPrivateNetworks() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockPrivate = true - for _, cidr := range []string{ - "0.0.0.0/8", // "this" network (reaches loopback on Linux/macOS) - "10.0.0.0/8", // RFC 1918 private - "172.16.0.0/12", // RFC 1918 private - "192.168.0.0/16", // RFC 1918 private - "127.0.0.0/8", // loopback - "100.64.0.0/10", // RFC 6598 shared/CGNAT (internal in cloud environments) - "169.254.0.0/16", // IPv4 link-local - "::1/128", // IPv6 loopback - "::/128", // IPv6 unspecified (reaches loopback) - "fc00::/7", // IPv6 unique local addresses (RFC 4193, equiv of RFC 1918) - "fe80::/10", // IPv6 link-local - "2002::/16", // 6to4 relay (deprecated RFC 7526; can embed private IPv4) - "2001::/32", // Teredo tunneling (deprecated; can embed private IPv4) - "64:ff9b::/96", // NAT64 well-known prefix (RFC 6052; embeds IPv4 in last 32 bits) - "64:ff9b:1::/48", // NAT64 local-use prefix (RFC 8215; embeds IPv4 in last 32 bits) - } { + for _, cidr := range ssrfBlockedCIDRs { _, ipNet, err := net.ParseCIDR(cidr) if err == nil { p.blockedCIDRs = append(p.blockedCIDRs, ipNet) @@ -111,6 +93,27 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } +// OnBlocked registers a callback invoked whenever a URL or path check fails. +// The callback receives a human-readable description of the violation. +// This is intended for audit logging; the callback must not block. +func (p *MCPSecurityPolicy) OnBlocked(fn func(violation string)) *MCPSecurityPolicy { + p.mu.Lock() + defer p.mu.Unlock() + p.onBlocked = fn + return p +} + +// notifyBlocked invokes the onBlocked callback if registered. +func (p *MCPSecurityPolicy) notifyBlocked(violation string) { + p.mu.RLock() + onBlocked := p.onBlocked + p.mu.RUnlock() + + if onBlocked != nil { + onBlocked(violation) + } +} + // isLocalhostHost returns true if the host is localhost or a loopback address. func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -125,8 +128,17 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - defer p.mu.RUnlock() + err := p.checkURLCore(rawURL) + p.mu.RUnlock() + + if err != nil { + p.notifyBlocked(err.Error()) + } + return err +} + +func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) @@ -153,6 +165,13 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { // If the host is an IP literal, check it directly against blocked CIDRs. if ip := net.ParseIP(host); ip != nil { + // Check normalized IP form against blocked hosts — catches IPv4-mapped + // IPv6 forms like ::ffff:169.254.169.254 that bypass string matching. + if normalizedIP := ip.String(); normalizedIP != host { + if p.blockedHosts[strings.ToLower(normalizedIP)] { + return fmt.Errorf("blocked host: %s (normalized: %s)", host, normalizedIP) + } + } if err := p.checkIP(ip, host); err != nil { return err } @@ -180,70 +199,8 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { } func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(ip) { - return fmt.Errorf("blocked IP %s (CIDR %s) for host %s", ip, cidr, originalHost) - } - } - - if p.blockPrivate { - // Catch encoding variants (e.g., IPv4-compatible IPv6 like ::127.0.0.1) - // that may not match CIDR entries due to byte-length mismatch. - if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { - return fmt.Errorf("blocked IP %s (private/loopback/link-local) for host %s", ip, originalHost) - } - - // Handle encoding variants that Go's net.IP methods don't classify, by extracting - // the embedded IPv4 address and re-checking it against all blocked ranges. - if len(ip) == net.IPv6len && ip.To4() == nil { - // IPv4-compatible (::x.x.x.x, RFC 4291 §2.5.5.1): first 12 bytes are zero. - isV4Compatible := true - for i := 0; i < 12; i++ { - if ip[i] != 0 { - isV4Compatible = false - break - } - } - if isV4Compatible && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } - - // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765 §4.2.1): bytes 0-7 zero, - // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4. - // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil. - isV4Translated := ip[8] == 0xFF && ip[9] == 0xFF && ip[10] == 0x00 && ip[11] == 0x00 - if isV4Translated { - for i := 0; i < 8; i++ { - if ip[i] != 0 { - isV4Translated = false - break - } - } - } - if isV4Translated && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } - } + if _, detail, blocked := ssrfCheckIP(ip, originalHost, p.blockedCIDRs, p.blockPrivate); blocked { + return fmt.Errorf("%s", detail) } return nil @@ -251,10 +208,31 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. +// +// SECURITY NOTE (TOCTOU): This check is inherently susceptible to +// time-of-check-to-time-of-use races — the filesystem state may change between +// the validation here and the actual file access by the caller. +// +// Mitigations callers should consider: +// - Use O_NOFOLLOW when opening files after validation. +// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on +// Linux 5.6+) where possible. +// - Open the file immediately after validation and re-verify the resolved +// path via /proc/self/fd/N or fstat before processing. +// - Avoid writing to directories that untrusted users can modify. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - defer p.mu.RUnlock() + err := p.checkPathCore(path) + p.mu.RUnlock() + if err != nil { + p.notifyBlocked(err.Error()) + } + + return err +} + +func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil } @@ -348,3 +326,85 @@ func resolveExistingPrefix(p string) string { } } } + +// redirectBlockedHosts lists hostnames that HTTP redirects must never follow. +// This covers cloud metadata services that attackers commonly target via +// redirect-based SSRF (the initial request hits an allowed host which 302s +// to the metadata endpoint). +var redirectBlockedHosts = map[string]bool{ + "169.254.169.254": true, + "fd00:ec2::254": true, + "metadata.google.internal": true, + "100.100.100.200": true, +} + +// SSRFSafeRedirect is an http.Client CheckRedirect function that blocks +// redirects to cloud metadata endpoints, private/loopback addresses, and +// non-HTTP(S) schemes. Hostnames in redirect targets are resolved via DNS +// and all resulting IPs are checked. DNS failures are treated as blocked +// (fail-closed) to prevent bypass via DNS rebinding. Assign it to +// http.Client.CheckRedirect. +func SSRFSafeRedirect(req *http.Request, via []*http.Request) error { + host := strings.ToLower(req.URL.Hostname()) + + // Block known metadata endpoints (string match). + if redirectBlockedHosts[host] { + return fmt.Errorf("redirect to blocked metadata host: %s", host) + } + + if ip := net.ParseIP(host); ip != nil { + // Check normalized IP against metadata list — catches IPv4-mapped IPv6 + // forms like [::ffff:169.254.169.254] that bypass string matching. + if normalizedIP := ip.String(); redirectBlockedHosts[strings.ToLower(normalizedIP)] { + return fmt.Errorf("redirect to blocked metadata host: %s", host) + } + if err := checkRedirectIP(ip, host); err != nil { + return err + } + } else { + // Hostname — resolve and check all IPs (fail-closed on DNS failure). + addrs, err := net.LookupHost(host) + if err != nil { + return fmt.Errorf( + "redirect to %s blocked: DNS resolution failed (fail-closed, SSRF protection)", host, + ) + } + for _, addr := range addrs { + if redirectBlockedHosts[strings.ToLower(addr)] { + return fmt.Errorf( + "redirect to %s blocked: resolves to metadata endpoint %s (SSRF protection)", + host, addr, + ) + } + if resolvedIP := net.ParseIP(addr); resolvedIP != nil { + if err := checkRedirectIP(resolvedIP, host); err != nil { + return err + } + } + } + } + + // Block non-HTTP(S) scheme redirects (e.g., file://, gopher://). + switch req.URL.Scheme { + case "http", "https": + // allowed + default: + return fmt.Errorf("redirect to disallowed scheme: %s", req.URL.Scheme) + } + + // Enforce standard redirect limit. + if len(via) >= 10 { + return fmt.Errorf("stopped after %d redirects", len(via)) + } + + return nil +} + +// checkRedirectIP checks whether an IP is in a private, loopback, link-local, +// or unspecified range. Used by [SSRFSafeRedirect] to block redirect-based SSRF. +func checkRedirectIP(ip net.IP, host string) error { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection, host: %s)", ip, host) + } + return nil +} diff --git a/cli/azd/pkg/azdext/mcp_security_test.go b/cli/azd/pkg/azdext/mcp_security_test.go index 6cc0d1d44da..cb9610dc18e 100644 --- a/cli/azd/pkg/azdext/mcp_security_test.go +++ b/cli/azd/pkg/azdext/mcp_security_test.go @@ -5,6 +5,8 @@ package azdext import ( "fmt" + "net/http" + "net/url" "os" "path/filepath" "strings" @@ -19,6 +21,9 @@ func TestMCPSecurityCheckURL_BlocksMetadataEndpoints(t *testing.T) { "http://fd00:ec2::254/latest/meta-data/", "http://metadata.google.internal/computeMetadata/v1/", "http://100.100.100.200/latest/meta-data/", + // IPv4-mapped forms of metadata IPs — must be caught by IP normalization. + "http://[::ffff:169.254.169.254]/latest/meta-data/", + "http://[::ffff:100.100.100.200]/latest/meta-data/", } for _, u := range blocked { if err := policy.CheckURL(u); err == nil { @@ -27,6 +32,25 @@ func TestMCPSecurityCheckURL_BlocksMetadataEndpoints(t *testing.T) { } } +func TestSSRFSafeRedirect_BlocksPrivateHostnames(t *testing.T) { + t.Parallel() + tests := []struct { + host string + blocked bool + }{ + {"169.254.169.254", true}, + {"127.0.0.1", true}, + {"10.0.0.1", true}, + } + for _, tc := range tests { + req := &http.Request{URL: mustParseURL(t, "http://"+tc.host+"/path")} + err := SSRFSafeRedirect(req, nil) + if tc.blocked && err == nil { + t.Errorf("SSRFSafeRedirect(%s) = nil, want error", tc.host) + } + } +} + func TestMCPSecurityCheckURL_BlocksPrivateIPs(t *testing.T) { policy := NewMCPSecurityPolicy().BlockPrivateNetworks() @@ -335,3 +359,12 @@ func TestMCPSecurityFluentBuilder(t *testing.T) { t.Errorf("expected 1 base path, got %d", len(policy.allowedBasePaths)) } } + +func mustParseURL(t *testing.T, rawURL string) *url.URL { + t.Helper() + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("mustParseURL(%q): %v", rawURL, err) + } + return u +} diff --git a/cli/azd/pkg/azdext/output.go b/cli/azd/pkg/azdext/output.go new file mode 100644 index 00000000000..ada93d09569 --- /dev/null +++ b/cli/azd/pkg/azdext/output.go @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/fatih/color" +) + +// OutputFormat represents the output format for extension commands. +type OutputFormat string + +const ( + // OutputFormatDefault is human-readable text with optional color. + OutputFormatDefault OutputFormat = "default" + // OutputFormatJSON outputs structured JSON for machine consumption. + OutputFormatJSON OutputFormat = "json" +) + +// ParseOutputFormat converts a string to an OutputFormat. +// Returns OutputFormatDefault for unrecognized values and a non-nil error. +func ParseOutputFormat(s string) (OutputFormat, error) { + switch strings.ToLower(s) { + case "default", "": + return OutputFormatDefault, nil + case "json": + return OutputFormatJSON, nil + default: + return OutputFormatDefault, fmt.Errorf("invalid output format %q (valid: default, json)", s) + } +} + +// OutputOptions configures an [Output] instance. +type OutputOptions struct { + // Format controls the output style. Defaults to OutputFormatDefault. + Format OutputFormat + // Writer is the destination for normal output. Defaults to os.Stdout. + Writer io.Writer + // ErrWriter is the destination for error/warning output. Defaults to os.Stderr. + ErrWriter io.Writer +} + +// Output provides formatted, format-aware output for extension commands. +// In default mode it writes human-readable text with ANSI color; in JSON mode +// it writes structured JSON objects to stdout and suppresses decorative output. +// +// Output is safe for use from a single goroutine. If concurrent use is needed +// callers should synchronize externally. +type Output struct { + writer io.Writer + errWriter io.Writer + format OutputFormat + + // Color printers — configured once at construction. + successColor *color.Color + warningColor *color.Color + errorColor *color.Color + infoColor *color.Color + headerColor *color.Color + dimColor *color.Color +} + +// NewOutput creates an Output configured by opts. +// If opts.Writer or opts.ErrWriter are nil they default to os.Stdout / os.Stderr. +func NewOutput(opts OutputOptions) *Output { + w := opts.Writer + if w == nil { + w = os.Stdout + } + ew := opts.ErrWriter + if ew == nil { + ew = os.Stderr + } + + return &Output{ + writer: w, + errWriter: ew, + format: opts.Format, + successColor: color.New(color.FgGreen), + warningColor: color.New(color.FgYellow), + errorColor: color.New(color.FgRed), + infoColor: color.New(color.FgCyan), + headerColor: color.New(color.Bold), + dimColor: color.New(color.Faint), + } +} + +// IsJSON returns true when the output format is JSON. +// Callers can use this to skip decorative output that is only relevant in +// human-readable mode. +func (o *Output) IsJSON() bool { + return o.format == OutputFormatJSON +} + +// Success prints a success message prefixed with a green check mark. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Success(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.successColor.Fprintf(o.writer, "(✓) Done: %s\n", msg) +} + +// Warning prints a warning message prefixed with a yellow exclamation mark. +// Warnings are written to ErrWriter in both default and JSON mode so they are +// visible even when stdout is piped through a JSON consumer. +func (o *Output) Warning(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + // In JSON mode emit a structured warning to stderr. + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "warning", + "message": msg, + }) + return + } + o.warningColor.Fprintf(o.errWriter, "(!) Warning: %s\n", sanitizeOutputText(msg)) +} + +// Error prints an error message prefixed with a red cross. +// Errors are always written to ErrWriter. +func (o *Output) Error(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "error", + "message": msg, + }) + return + } + o.errorColor.Fprintf(o.errWriter, "(✗) Error: %s\n", sanitizeOutputText(msg)) +} + +// Info prints an informational message prefixed with an info symbol. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Info(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + o.infoColor.Fprintf(o.writer, "(i) %s\n", msg) +} + +// Message prints an undecorated message to stdout. +// In JSON mode the call is a no-op. +func (o *Output) Message(format string, args ...any) { + if o.IsJSON() { + return + } + msg := sanitizeOutputText(fmt.Sprintf(format, args...)) + fmt.Fprintln(o.writer, msg) +} + +// JSON writes data as a pretty-printed JSON object to stdout. +// It is active in all output modes so callers can unconditionally emit +// structured payloads (in default mode the JSON is still human-readable). +func (o *Output) JSON(data any) error { + enc := json.NewEncoder(o.writer) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { + return fmt.Errorf("output: failed to encode JSON: %w", err) + } + return nil +} + +// Table prints a formatted text table with headers and rows. +// In JSON mode the table is emitted as a JSON array of objects instead. +// +// headers defines the column names. Each row is a slice of cell values +// with the same length as headers. Rows with fewer cells are padded with +// empty strings; extra cells are silently ignored. +func (o *Output) Table(headers []string, rows [][]string) { + if len(headers) == 0 { + return + } + + if o.IsJSON() { + o.tableJSON(headers, rows) + return + } + + o.tableText(headers, rows) +} + +// tableJSON emits the table as a JSON array of objects keyed by header name. +func (o *Output) tableJSON(headers []string, rows [][]string) { + out := make([]map[string]string, 0, len(rows)) + for _, row := range rows { + obj := make(map[string]string, len(headers)) + for i, h := range headers { + if i < len(row) { + obj[h] = row[i] + } else { + obj[h] = "" + } + } + out = append(out, obj) + } + _ = o.JSON(out) +} + +// tableText renders an aligned text table with a header separator. +func (o *Output) tableText(headers []string, rows [][]string) { + // Calculate column widths. + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + for _, row := range rows { + for i := range headers { + if i < len(row) && len(row[i]) > widths[i] { + widths[i] = len(row[i]) + } + } + } + + // Print header row. + for i, h := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + o.headerColor.Fprintf(o.writer, "%-*s", widths[i], h) + } + fmt.Fprintln(o.writer) + + // Print separator. + for i, w := range widths { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + fmt.Fprint(o.writer, strings.Repeat("─", w)) + } + fmt.Fprintln(o.writer) + + // Print data rows. + for _, row := range rows { + for i := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + cell := "" + if i < len(row) { + cell = sanitizeOutputText(row[i]) + } + fmt.Fprintf(o.writer, "%-*s", widths[i], cell) + } + fmt.Fprintln(o.writer) + } +} + +// sanitizeOutputText replaces CR, LF, and other ASCII control characters +// (except TAB) with a space so that untrusted values embedded in text-mode +// output cannot forge log lines or inject terminal escape sequences. +// JSON-mode output is NOT sanitized here because json.Encoder already +// escapes control characters in string values. +func sanitizeOutputText(s string) string { + return strings.Map(func(r rune) rune { + if r == '\t' { + return r + } + if r < 0x20 || r == 0x7F { + return ' ' + } + return r + }, s) +} diff --git a/cli/azd/pkg/azdext/output_test.go b/cli/azd/pkg/azdext/output_test.go new file mode 100644 index 00000000000..4a31fe7ccae --- /dev/null +++ b/cli/azd/pkg/azdext/output_test.go @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ParseOutputFormat +// --------------------------------------------------------------------------- + +func TestParseOutputFormat(t *testing.T) { + tests := []struct { + name string + input string + expected OutputFormat + expectErr bool + }{ + {name: "default string", input: "default", expected: OutputFormatDefault}, + {name: "empty string", input: "", expected: OutputFormatDefault}, + {name: "json lowercase", input: "json", expected: OutputFormatJSON}, + {name: "JSON uppercase", input: "JSON", expected: OutputFormatJSON}, + {name: "Json mixed case", input: "Json", expected: OutputFormatJSON}, + {name: "DEFAULT uppercase", input: "DEFAULT", expected: OutputFormatDefault}, + {name: "invalid format", input: "xml", expected: OutputFormatDefault, expectErr: true}, + {name: "invalid format yaml", input: "yaml", expected: OutputFormatDefault, expectErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseOutputFormat(tt.input) + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid output format") + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// NewOutput defaults +// --------------------------------------------------------------------------- + +func TestNewOutput_DefaultWriters(t *testing.T) { + out := NewOutput(OutputOptions{}) + require.NotNil(t, out) + // Default format should be "default" (zero-value of OutputFormat). + require.False(t, out.IsJSON()) +} + +func TestNewOutput_JSONMode(t *testing.T) { + out := NewOutput(OutputOptions{Format: OutputFormatJSON}) + require.True(t, out.IsJSON()) +} + +// --------------------------------------------------------------------------- +// Success +// --------------------------------------------------------------------------- + +func TestOutput_Success_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Success("deployed %s", "myapp") + + // Should contain the message text (color codes may wrap it). + require.Contains(t, buf.String(), "Done: deployed myapp") +} + +func TestOutput_Success_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Success("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Warning +// --------------------------------------------------------------------------- + +func TestOutput_Warning_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Warning("deprecated %s", "v1") + + require.Contains(t, errBuf.String(), "Warning: deprecated v1") +} + +func TestOutput_Warning_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Warning("api deprecated") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "warning", parsed["level"]) + require.Equal(t, "api deprecated", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +func TestOutput_Error_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Error("connection failed: %s", "timeout") + + require.Contains(t, errBuf.String(), "Error: connection failed: timeout") +} + +func TestOutput_Error_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Error("disk full") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "error", parsed["level"]) + require.Equal(t, "disk full", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Info +// --------------------------------------------------------------------------- + +func TestOutput_Info_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Info("fetching %d items", 5) + + require.Contains(t, buf.String(), "fetching 5 items") +} + +func TestOutput_Info_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Info("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Message +// --------------------------------------------------------------------------- + +func TestOutput_Message_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Message("plain text %d", 42) + + require.Equal(t, "plain text 42\n", buf.String()) +} + +func TestOutput_Message_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Message("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// JSON +// --------------------------------------------------------------------------- + +func TestOutput_JSON_Struct(t *testing.T) { + type result struct { + Name string `json:"name"` + Count int `json:"count"` + } + + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(result{Name: "test", Count: 7}) + require.NoError(t, err) + + var decoded result + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "test", decoded.Name) + require.Equal(t, 7, decoded.Count) +} + +func TestOutput_JSON_Map(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]string{"key": "value"}) + require.NoError(t, err) + + var decoded map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "value", decoded["key"]) +} + +func TestOutput_JSON_Unmarshalable(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(make(chan int)) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to encode JSON") +} + +func TestOutput_JSON_PrettyPrinted(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]int{"a": 1}) + require.NoError(t, err) + + // Verify indentation is present (pretty-printed). + require.Contains(t, buf.String(), " ") +} + +// --------------------------------------------------------------------------- +// Table — default format +// --------------------------------------------------------------------------- + +func TestOutput_Table_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"Name", "Status"} + rows := [][]string{ + {"api", "running"}, + {"web", "stopped"}, + } + + out.Table(headers, rows) + + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "Status") + require.Contains(t, text, "api") + require.Contains(t, text, "running") + require.Contains(t, text, "web") + require.Contains(t, text, "stopped") + + // Separator line should be present. + require.Contains(t, text, "─") +} + +func TestOutput_Table_EmptyHeaders(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table(nil, [][]string{{"a"}}) + + require.Empty(t, buf.String()) +} + +func TestOutput_Table_EmptyRows(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table([]string{"Name"}, nil) + + // Header + separator should still be printed. + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "─") +} + +func TestOutput_Table_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + // Row has fewer cells than headers — should pad with empty strings. + out.Table([]string{"A", "B", "C"}, [][]string{{"only-a"}}) + + text := buf.String() + require.Contains(t, text, "only-a") + // No panic from short row. + lines := strings.Split(strings.TrimSpace(text), "\n") + require.Len(t, lines, 3) // header + separator + 1 data row +} + +func TestOutput_Table_ColumnAlignment(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"ID", "LongerName"} + rows := [][]string{ + {"1", "short"}, + {"2", "a-much-longer-value"}, + } + + out.Table(headers, rows) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.GreaterOrEqual(t, len(lines), 3) + + // All separator dashes should align with header width. + sepLine := lines[1] + require.NotEmpty(t, sepLine) +} + +// --------------------------------------------------------------------------- +// Table — JSON format +// --------------------------------------------------------------------------- + +func TestOutput_Table_JSONFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + headers := []string{"Service", "Port"} + rows := [][]string{ + {"api", "8080"}, + {"web", "3000"}, + } + + out.Table(headers, rows) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 2) + require.Equal(t, "api", decoded[0]["Service"]) + require.Equal(t, "8080", decoded[0]["Port"]) + require.Equal(t, "web", decoded[1]["Service"]) + require.Equal(t, "3000", decoded[1]["Port"]) +} + +func TestOutput_Table_JSONFormat_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Table([]string{"A", "B"}, [][]string{{"only-a"}}) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 1) + require.Equal(t, "only-a", decoded[0]["A"]) + require.Equal(t, "", decoded[0]["B"]) +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 6f95645686d..2d9af0fda35 100644 --- a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -14,10 +14,19 @@ import ( "strings" ) +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + const ( // maxPageResponseSize limits the maximum size of a single page response // body to prevent excessive memory consumption from malicious or - // misconfigured servers. + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. maxPageResponseSize int64 = 10 << 20 // 10 MB // maxErrorBodySize limits the size of error response bodies captured @@ -42,8 +51,10 @@ type Pager[T any] struct { client HTTPDoer nextURL string done bool + truncated bool opts PagerOptions originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far } // PageResponse is a single page returned by [Pager.NextPage]. @@ -59,6 +70,18 @@ type PageResponse[T any] struct { type PagerOptions struct { // Method overrides the HTTP method used for page requests. Defaults to GET. Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. + MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). + MaxItems int } // HTTPDoer abstracts the HTTP call so that [ResilientClient] or any @@ -117,6 +140,11 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } +// Truncated reports whether the most recent Collect call stopped early due to MaxPages or MaxItems limits. +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + // NextPage fetches the next page of results. Returns an error if the request // fails, the response is not 2xx, or the body cannot be decoded. // @@ -145,7 +173,7 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: string(body), + Body: sanitizeErrorBody(string(body)), } } @@ -170,6 +198,9 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { p.nextURL = page.NextLink } + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + return &page, nil } @@ -199,12 +230,23 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { } // Collect is a convenience method that fetches all remaining pages and -// returns all items in a single slice. Use with caution on large result sets. +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. // // If NextPage returns both page data and an error (e.g. rejected nextLink), // the page data is included in the returned slice before returning the error. func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + p.truncated = false + + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } for p.More() { page, err := p.NextPage(ctx) @@ -214,11 +256,28 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { if err != nil { return all, err } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + p.truncated = true + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. + if p.pageCount >= maxPages { + p.truncated = true + break + } } return all, nil } +const maxPaginationErrorBodyLen = 1024 + // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int @@ -229,6 +288,40 @@ type PaginationError struct { func (e *PaginationError) Error() string { return fmt.Sprintf( "azdext.Pager: page request returned HTTP %d (url=%s)", - e.StatusCode, e.URL, + e.StatusCode, redactURL(e.URL), ) } + +func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. +func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 1f8c3537279..c800daa1a07 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -481,3 +482,98 @@ func TestPager_CollectWithSSRFError(t *testing.T) { t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) } } + +func TestPager_CollectTruncatedByMaxPages(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a"}, fmt.Sprintf("https://example.com/api?page=%d", 2)) + page2 := pageJSON([]string{"b"}, fmt.Sprintf("https://example.com/api?page=%d", 3)) + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page2)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 1}) + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + if len(all) != 1 || all[0] != "a" { + t.Errorf("all = %v, want [a]", all) + } + if !pager.Truncated() { + t.Error("expected pager to be truncated by MaxPages") + } +} + +func TestPager_CollectTruncatedByMaxItems(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a", "b"}, fmt.Sprintf("https://example.com/api?page=%d", 2)) + page2 := pageJSON([]string{"c", "d"}, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page2)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 3}) + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + if len(all) != 3 || all[0] != "a" || all[1] != "b" || all[2] != "c" { + t.Errorf("all = %v, want [a b c]", all) + } + if !pager.Truncated() { + t.Error("expected pager to be truncated by MaxItems") + } +} + +func TestPager_CollectNotTruncatedAtEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + if len(all) != 2 || all[0] != "a" || all[1] != "b" { + t.Errorf("all = %v, want [a b]", all) + } + if pager.Truncated() { + t.Error("expected pager to not be truncated at natural end") + } +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index 2916b885334..3c20d5d691a 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -9,6 +9,8 @@ import ( "fmt" "io" "math" + "math/rand/v2" + "net" "net/http" "strconv" "time" @@ -88,6 +90,10 @@ func (o *ResilientClientOptions) defaults() { // tokenProvider may be nil if the caller handles Authorization headers manually. // When non-nil, the client automatically injects a Bearer token using scopes // resolved from the request URL via the [ScopeDetector]. +// +// The client uses a safe transport that validates resolved IP addresses at +// connection time, mitigating DNS rebinding TOCTOU attacks where a hostname +// resolves to a safe IP at check time but to a private IP at connect time. func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientClientOptions) *ResilientClient { if opts == nil { opts = &ResilientClientOptions{} @@ -97,7 +103,7 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli transport := opts.Transport if transport == nil { - transport = http.DefaultTransport + transport = ssrfSafeTransport() } sd := opts.ScopeDetector @@ -107,8 +113,9 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli return &ResilientClient{ httpClient: &http.Client{ - Transport: transport, - Timeout: opts.Timeout, + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, }, tokenProvider: tokenProvider, scopeDetector: sd, @@ -116,6 +123,49 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli } } +// ssrfSafeTransport returns an [http.RoundTripper] that validates every +// resolved IP address at dial time against SSRF block lists. This closes the +// DNS-rebinding TOCTOU gap where a hostname passes the pre-request SSRF check +// (resolves to a public IP) but is re-resolved by the dialer to a private IP. +// +// The dialer uses the default SSRF guard (metadata + private networks blocked). +func ssrfSafeTransport() http.RoundTripper { + guard := DefaultSSRFGuard() + dialer := &net.Dialer{Timeout: 30 * time.Second} + + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("azdext: invalid address %q: %w", addr, err) + } + + // Resolve and validate each IP before connecting. + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("azdext: DNS resolution failed for %s (fail-closed): %w", host, err) + } + + for _, ipAddr := range ips { + if reason, detail, blocked := ssrfCheckIP( + ipAddr.IP, host, guard.blockedCIDRs, guard.blockPrivate, + ); blocked { + return nil, fmt.Errorf( + "azdext: SSRF dial blocked: %s: %s (host=%s)", reason, detail, host) + } + } + + // Connect to the first resolved IP (already validated). + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} + // Do executes an HTTP request with retry logic and optional bearer-token injection. // // body may be nil for requests without a body (GET, DELETE). @@ -127,6 +177,14 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } + if body != nil && rc.opts.MaxRetries > 0 { + if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } + var lastErr error var retryAfterOverride time.Duration @@ -147,14 +205,9 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R case <-time.After(delay): } - // Reset body for retry — require io.ReadSeeker for non-nil bodies. + // Reset body for retry. if body != nil { - seeker, ok := body.(io.ReadSeeker) - if !ok { - return nil, errors.New( - "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + - "retries require a seekable body (use bytes.NewReader or strings.NewReader)") - } + seeker := body.(io.ReadSeeker) if _, err := seeker.Seek(0, io.SeekStart); err != nil { return nil, fmt.Errorf("azdext.ResilientClient.Do: failed to reset request body: %w", err) } @@ -239,8 +292,9 @@ func (rc *ResilientClient) backoff(attempt int) time.Duration { if delay > rc.opts.MaxDelay { delay = rc.opts.MaxDelay } + jitter := 0.5 + rand.Float64()*0.5 - return delay + return time.Duration(float64(delay) * jitter) } // isRetryable returns true for status codes that indicate a transient failure. @@ -292,6 +346,12 @@ func retryAfterFromResponse(resp *http.Response) time.Duration { } if n, _ := strconv.Atoi(v); n > 0 { + // Cap parsed value before multiplication to prevent integer overflow + // (a crafted Retry-After header could wrap int64, bypassing maxRetryAfterDuration). + maxN := int(maxRetryAfterDuration / rh.units) + if n > maxN { + return maxRetryAfterDuration + } return time.Duration(n) * rh.units } diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 4ffd67af103..c45411f8162 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -530,9 +530,9 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made exactly 1 attempt (first gets 503 → retry → fail on body check). - if attempts != 1 { - t.Errorf("attempts = %d, want 1 (fail before second attempt)", attempts) + // Should fail before first attempt because seekability is validated up front. + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail before first attempt)", attempts) } } @@ -609,14 +609,81 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { t.Errorf("maxRetryAfterDuration = %v, should be <= 5m", maxRetryAfterDuration) } - // A large Retry-After value should be capped in Do(). + // A large Retry-After value should be capped at parse time to prevent + // integer overflow (crafted values could wrap int64 and bypass the cap in Do). h := http.Header{} h.Set("retry-after", "999999") resp := &http.Response{Header: h} got := retryAfterFromResponse(resp) - // retryAfterFromResponse itself doesn't cap (pure parser), but Do() caps it. - if got != 999999*time.Second { - t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) + // retryAfterFromResponse now caps values to maxRetryAfterDuration to prevent overflow. + if got != maxRetryAfterDuration { + t.Errorf("retryAfterFromResponse() = %v, want %v (capped at parse time)", got, maxRetryAfterDuration) + } +} + +func TestResilientClient_BackoffJitter(t *testing.T) { + t.Parallel() + rc := NewResilientClient(nil, &ResilientClientOptions{InitialDelay: 100 * time.Millisecond, MaxDelay: 10 * time.Second}) + seen := make(map[time.Duration]bool) + for range 20 { + d := rc.backoff(1) + seen[d] = true + if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport, MaxRetries: 2, InitialDelay: time.Millisecond}) + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + if attempts != 0 { + t.Errorf("attempts = %d, want 0", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + h := http.Header{} + h.Set("retry-after", "999999") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + }) + rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport, MaxRetries: 1, InitialDelay: time.Millisecond}) + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded, got: %v", err) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1", attempts) } } diff --git a/cli/azd/pkg/azdext/security_validation.go b/cli/azd/pkg/azdext/security_validation.go new file mode 100644 index 00000000000..5d5e5dd1c7e --- /dev/null +++ b/cli/azd/pkg/azdext/security_validation.go @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +// --------------------------------------------------------------------------- +// Validation error type +// --------------------------------------------------------------------------- + +// ValidationError describes a failed input validation with structured context. +type ValidationError struct { + // Field is the logical name of the input being validated (e.g. "service_name"). + Field string + + // Value is the rejected input value. For security-sensitive inputs the value + // may be truncated or redacted by the caller before constructing the error. + Value string + + // Rule is a short machine-readable tag for the violated constraint + // (e.g. "format", "length", "characters"). + Rule string + + // Message is a human-readable explanation suitable for end-user display. + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("azdext.Validate: %s %q: %s (%s)", e.Field, e.Value, e.Message, e.Rule) +} + +// --------------------------------------------------------------------------- +// Service name validation +// --------------------------------------------------------------------------- + +// serviceNameRe matches DNS-safe service names: +// - starts with alphanumeric +// - contains only alphanumeric, '.', '_', '-' +// - 1–63 characters total (DNS label limit) +var serviceNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,62}$`) + +// ValidateServiceName checks that name is a valid DNS-safe service identifier. +// +// Rules: +// - Must start with an alphanumeric character. +// - May contain alphanumeric characters, '.', '_', and '-'. +// - Must be 1–63 characters (DNS label length limit per RFC 1035). +// +// Returns a [*ValidationError] on failure. +func ValidateServiceName(name string) error { + if name == "" { + return &ValidationError{ + Field: "service_name", + Value: "", + Rule: "required", + Message: "service name must not be empty", + } + } + + if !serviceNameRe.MatchString(name) { + return &ValidationError{ + Field: "service_name", + Value: truncateValue(name, 64), + Rule: "format", + Message: "service name must start with alphanumeric and contain only [a-zA-Z0-9._-], max 63 chars", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Hostname validation +// --------------------------------------------------------------------------- + +// hostnameRe matches RFC 952/1123 hostnames: +// - labels separated by '.' +// - each label: starts and ends with alphanumeric, may contain '-', 1–63 chars +// - total length <= 253 characters +var hostnameRe = regexp.MustCompile( + `^[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?` + + `(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`, +) + +// ValidateHostname checks that hostname conforms to RFC 952/1123. +// +// Rules: +// - Each label must start and end with an alphanumeric character. +// - Labels may contain alphanumeric characters and '-'. +// - Each label is 1–63 characters. +// - Total hostname length is ≤ 253 characters. +// +// Returns a [*ValidationError] on failure. +func ValidateHostname(hostname string) error { + if hostname == "" { + return &ValidationError{ + Field: "hostname", + Value: "", + Rule: "required", + Message: "hostname must not be empty", + } + } + + if len(hostname) > 253 { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "length", + Message: "hostname must not exceed 253 characters", + } + } + + if !hostnameRe.MatchString(hostname) { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "format", + Message: "hostname must conform to RFC 952/1123 " + + "(labels: alphanumeric start/end, may contain '-', 1-63 chars each)", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Script name validation +// --------------------------------------------------------------------------- + +// shellMetacharacters contains characters that have special meaning in common +// shells (bash, sh, zsh, cmd, PowerShell). A script name containing any of +// these is rejected to prevent command injection. +const shellMetacharacters = ";|&`$(){}[]<>!#~*?\"\\'%\n\r\x00" + +// ValidateScriptName checks that name does not contain shell metacharacters +// or path traversal sequences that could lead to command injection. +// +// Rejected patterns: +// - Shell metacharacters: ; | & ` $ ( ) { } [ ] < > ! # ~ * ? " ' \ % +// - Path traversal: ".." +// - Null bytes and newlines +// - Empty names +// +// Returns a [*ValidationError] on failure. +func ValidateScriptName(name string) error { + if name == "" { + return &ValidationError{ + Field: "script_name", + Value: "", + Rule: "required", + Message: "script name must not be empty", + } + } + + if strings.Contains(name, "..") { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "traversal", + Message: "script name must not contain path traversal sequences (..)", + } + } + + if idx := strings.IndexAny(name, shellMetacharacters); idx >= 0 { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "characters", + Message: fmt.Sprintf("script name contains forbidden shell metacharacter at position %d", idx), + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Container environment detection +// --------------------------------------------------------------------------- + +// containerEnvVars maps environment variables to the container runtime they indicate. +var containerEnvVars = map[string]string{ + "CODESPACES": "codespaces", + "KUBERNETES_SERVICE_HOST": "kubernetes", + "REMOTE_CONTAINERS": "devcontainer", + "REMOTE_CONTAINERS_IPC": "devcontainer", +} + +// IsContainerEnvironment reports whether the current process is running inside +// a container environment. It checks for: +// - GitHub Codespaces (CODESPACES env var) +// - Kubernetes (KUBERNETES_SERVICE_HOST env var) +// - VS Code Dev Containers (REMOTE_CONTAINERS / REMOTE_CONTAINERS_IPC env vars) +// - Docker (/.dockerenv file) +// +// The detection is best-effort and does not guarantee accuracy in all +// environments. It is intended for feature gating and diagnostics, not +// security decisions. +func IsContainerEnvironment() bool { + // Check well-known environment variables. + for envKey := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + return true + } + } + + // Check for Docker's marker file. + if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + + return false +} + +// ContainerRuntime returns the detected container runtime name, or an empty +// string if no container environment is detected. +// +// Possible return values: "codespaces", "kubernetes", "devcontainer", "docker", "". +func ContainerRuntime() string { + for envKey, runtime := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + return runtime + } + } + + if _, err := os.Stat("/.dockerenv"); err == nil { + return "docker" + } + + return "" +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// truncateValue truncates s to maxLen characters for safe inclusion in error +// messages. If truncated, an ellipsis is appended. +func truncateValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/security_validation_test.go b/cli/azd/pkg/azdext/security_validation_test.go new file mode 100644 index 00000000000..cd4165e67dd --- /dev/null +++ b/cli/azd/pkg/azdext/security_validation_test.go @@ -0,0 +1,354 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "errors" + "os" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// ValidateServiceName +// --------------------------------------------------------------------------- + +func TestValidateServiceName_Valid(t *testing.T) { + valid := []string{ + "web", + "my-service", + "my_service", + "my.service", + "A1", + "a", + "service-v2.0", + "a23456789012345678901234567890123456789012345678901234567890123", // 63 chars + } + for _, name := range valid { + if err := ValidateServiceName(name); err != nil { + t.Errorf("ValidateServiceName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateServiceName_Invalid(t *testing.T) { + tests := []struct { + name string + rule string + }{ + {"", "required"}, + {"-starts-with-dash", "format"}, + {".starts-with-dot", "format"}, + {"_starts-with-underscore", "format"}, + {"has space", "format"}, + {"has;semicolon", "format"}, + {"has|pipe", "format"}, + {"has$dollar", "format"}, + {"has/slash", "format"}, + {"has@at", "format"}, + // 64 chars (too long) + {"a234567890123456789012345678901234567890123456789012345678901234", "format"}, + } + for _, tc := range tests { + err := ValidateServiceName(tc.name) + if err == nil { + t.Errorf("ValidateServiceName(%q) = nil, want error", tc.name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateServiceName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateServiceName(%q).Rule = %q, want %q", tc.name, ve.Rule, tc.rule) + } + if ve.Field != "service_name" { + t.Errorf("ValidateServiceName(%q).Field = %q, want %q", tc.name, ve.Field, "service_name") + } + } +} + +// --------------------------------------------------------------------------- +// ValidateHostname +// --------------------------------------------------------------------------- + +func TestValidateHostname_Valid(t *testing.T) { + valid := []string{ + "example.com", + "sub.example.com", + "a.b.c.d.example.com", + "my-host", + "a", + "1", + "192-168-1-1.nip.io", + "xn--nxasmq6b.example.com", // punycode + } + for _, h := range valid { + if err := ValidateHostname(h); err != nil { + t.Errorf("ValidateHostname(%q) = %v, want nil", h, err) + } + } +} + +func TestValidateHostname_Invalid(t *testing.T) { + tests := []struct { + hostname string + rule string + }{ + {"", "required"}, + {"-starts-with-dash.com", "format"}, + {"ends-with-dash-.com", "format"}, + {"has space.com", "format"}, + {"has_underscore.com", "format"}, + {"has..double-dot.com", "format"}, + {".starts-with-dot.com", "format"}, + {"has;semicolon.com", "format"}, + // 254 chars (too long) + {strings.Repeat("a", 254), "length"}, + } + for _, tc := range tests { + err := ValidateHostname(tc.hostname) + if err == nil { + t.Errorf("ValidateHostname(%q) = nil, want error", tc.hostname) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateHostname(%q) returned %T, want *ValidationError", tc.hostname, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateHostname(%q).Rule = %q, want %q", tc.hostname, ve.Rule, tc.rule) + } + } +} + +func TestValidateHostname_LabelLength(t *testing.T) { + // Each label max 63 chars. A 64-char label should fail. + longLabel := strings.Repeat("a", 64) + ".com" + if err := ValidateHostname(longLabel); err == nil { + t.Errorf("ValidateHostname with 64-char label = nil, want error") + } + + // 63-char label should succeed. + okLabel := strings.Repeat("a", 63) + ".com" + if err := ValidateHostname(okLabel); err != nil { + t.Errorf("ValidateHostname with 63-char label = %v, want nil", err) + } +} + +// --------------------------------------------------------------------------- +// ValidateScriptName +// --------------------------------------------------------------------------- + +func TestValidateScriptName_Valid(t *testing.T) { + valid := []string{ + "script.sh", + "my-script.py", + "build_project.ps1", + "run", + "deploy-v2.sh", + "test.cmd", + "start server", // spaces are OK in script names (not metacharacters) + } + for _, name := range valid { + if err := ValidateScriptName(name); err != nil { + t.Errorf("ValidateScriptName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateScriptName_ShellMetacharacters(t *testing.T) { + dangerous := []struct { + name string + desc string + }{ + {"script;rm -rf /", "semicolon command chaining"}, + {"script|cat /etc/passwd", "pipe"}, + {"script&background", "ampersand"}, + {"script`whoami`", "backtick command substitution"}, + {"script$(id)", "dollar-paren command substitution"}, + {"script > /dev/null", "output redirect"}, + {"script < /etc/passwd", "input redirect"}, + {"script\nrm -rf /", "newline injection"}, + {"script\x00null", "null byte"}, + {"script'quoted'", "single quote"}, + {"script\"quoted\"", "double quote"}, + {"script\\escaped", "backslash"}, + {"script!history", "exclamation/history expansion"}, + {"script#comment", "hash/comment"}, + {"script~home", "tilde expansion"}, + {"script*glob", "glob star"}, + {"script?glob", "glob question"}, + {"script%env", "percent"}, + {"script(sub)", "open paren"}, + {"script{brace}", "open brace"}, + {"script[bracket]", "open bracket"}, + } + for _, tc := range dangerous { + err := ValidateScriptName(tc.name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error (%s)", tc.name, tc.desc) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != "characters" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q (%s)", tc.name, ve.Rule, "characters", tc.desc) + } + } +} + +func TestValidateScriptName_PathTraversal(t *testing.T) { + traversal := []string{ + "../etc/passwd", + "../../secret.sh", + "dir/../../../root.sh", + "..\\windows\\system32", + } + for _, name := range traversal { + err := ValidateScriptName(name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error", name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", name, err) + continue + } + if ve.Rule != "traversal" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q", name, ve.Rule, "traversal") + } + } +} + +func TestValidateScriptName_Empty(t *testing.T) { + err := ValidateScriptName("") + if err == nil { + t.Error("ValidateScriptName(\"\") = nil, want error") + return + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(\"\") returned %T, want *ValidationError", err) + return + } + if ve.Rule != "required" { + t.Errorf("ValidateScriptName(\"\").Rule = %q, want %q", ve.Rule, "required") + } +} + +// --------------------------------------------------------------------------- +// IsContainerEnvironment / ContainerRuntime +// --------------------------------------------------------------------------- + +func TestIsContainerEnvironment_EnvVars(t *testing.T) { + tests := []struct { + envKey string + runtime string + }{ + {"CODESPACES", "codespaces"}, + {"KUBERNETES_SERVICE_HOST", "kubernetes"}, + {"REMOTE_CONTAINERS", "devcontainer"}, + {"REMOTE_CONTAINERS_IPC", "devcontainer"}, + } + for _, tc := range tests { + t.Run(tc.envKey, func(t *testing.T) { + // Ensure the env var is clean before/after. + orig := os.Getenv(tc.envKey) + t.Setenv(tc.envKey, "true") + + if !IsContainerEnvironment() { + t.Errorf("IsContainerEnvironment() = false with %s set", tc.envKey) + } + + rt := ContainerRuntime() + if rt != tc.runtime { + t.Errorf("ContainerRuntime() = %q with %s set, want %q", rt, tc.envKey, tc.runtime) + } + + // Restore and verify negative case. + if orig == "" { + os.Unsetenv(tc.envKey) + } else { + os.Setenv(tc.envKey, orig) + } + }) + } +} + +func TestIsContainerEnvironment_NoContainerEnv(t *testing.T) { + // Clear all container-related env vars. + for envKey := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + t.Setenv(envKey, "") + os.Unsetenv(envKey) + } + } + + // In CI or local dev without Docker marker, this should return false. + // We can't guarantee /.dockerenv doesn't exist, but in typical test + // environments it won't. + runtime := ContainerRuntime() + // Only assert no env-var-based detection. Docker file detection is + // environment-dependent and not worth mocking here. + _ = runtime +} + +// --------------------------------------------------------------------------- +// ValidationError +// --------------------------------------------------------------------------- + +func TestValidationError_ErrorMessage(t *testing.T) { + err := &ValidationError{ + Field: "test_field", + Value: "bad-value", + Rule: "format", + Message: "value is invalid", + } + + msg := err.Error() + if !strings.Contains(msg, "test_field") { + t.Errorf("error message should contain field name, got: %s", msg) + } + if !strings.Contains(msg, "bad-value") { + t.Errorf("error message should contain value, got: %s", msg) + } + if !strings.Contains(msg, "format") { + t.Errorf("error message should contain rule, got: %s", msg) + } +} + +func TestTruncateValue(t *testing.T) { + tests := []struct { + input string + maxLen int + want string + }{ + {"short", 10, "short"}, + {"exactly10!", 10, "exactly10!"}, + {"this is way too long", 10, "this is wa..."}, + {"", 5, ""}, + } + for _, tc := range tests { + got := truncateValue(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf("truncateValue(%q, %d) = %q, want %q", tc.input, tc.maxLen, got, tc.want) + } + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// isValidationError is a type assertion helper for testing. +func isValidationError(err error, target **ValidationError) bool { + return errors.As(err, target) +} diff --git a/cli/azd/pkg/azdext/ssrf_common.go b/cli/azd/pkg/azdext/ssrf_common.go new file mode 100644 index 00000000000..91ad8595814 --- /dev/null +++ b/cli/azd/pkg/azdext/ssrf_common.go @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "net" +) + +// ssrfMetadataHosts lists well-known cloud metadata service hostnames/IPs. +var ssrfMetadataHosts = []string{ + "169.254.169.254", + "fd00:ec2::254", + "metadata.google.internal", + "100.100.100.200", +} + +// ssrfBlockedCIDRs lists CIDR blocks for private, loopback, link-local, and +// IPv6 transition mechanism networks. +var ssrfBlockedCIDRs = []string{ + "0.0.0.0/8", // "this" network (reaches loopback on Linux/macOS) + "10.0.0.0/8", // RFC 1918 private + "172.16.0.0/12", // RFC 1918 private + "192.168.0.0/16", // RFC 1918 private + "127.0.0.0/8", // loopback + "100.64.0.0/10", // RFC 6598 shared/CGNAT + "169.254.0.0/16", // IPv4 link-local + "::1/128", // IPv6 loopback + "::/128", // IPv6 unspecified + "fc00::/7", // IPv6 unique local (RFC 4193) + "fe80::/10", // IPv6 link-local + "2002::/16", // 6to4 relay (deprecated RFC 7526) + "2001::/32", // Teredo tunneling (deprecated) + "64:ff9b::/96", // NAT64 well-known prefix (RFC 6052) + "64:ff9b:1::/48", // NAT64 local-use prefix (RFC 8215) +} + +func ssrfCheckIP( + ip net.IP, + originalHost string, + blockedCIDRs []*net.IPNet, + blockPrivate bool, +) (string, string, bool) { + for _, cidr := range blockedCIDRs { + if cidr.Contains(ip) { + return "blocked_ip", fmt.Sprintf("IP %s matches blocked CIDR %s (host: %s)", ip, cidr, originalHost), true + } + } + + if !blockPrivate { + return "", "", false + } + + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return "private_network", fmt.Sprintf("IP %s is private/loopback/link-local (host: %s)", ip, originalHost), true + } + + if len(ip) != net.IPv6len || ip.To4() != nil { + return "", "", false + } + + if v4 := extractIPv4Compatible(ip); v4 != nil { + for _, cidr := range blockedCIDRs { + if cidr.Contains(v4) { + return "blocked_ip", fmt.Sprintf( + "IP %s (IPv4-compatible %s, CIDR %s) for host %s", + ip, v4, cidr, originalHost, + ), true + } + } + if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { + return "private_network", fmt.Sprintf( + "IP %s (IPv4-compatible %s, private/loopback) for host %s", + ip, v4, originalHost, + ), true + } + } + + if v4 := extractIPv4Translated(ip); v4 != nil { + for _, cidr := range blockedCIDRs { + if cidr.Contains(v4) { + return "blocked_ip", fmt.Sprintf( + "IP %s (IPv4-translated %s, CIDR %s) for host %s", + ip, v4, cidr, originalHost, + ), true + } + } + if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { + return "private_network", fmt.Sprintf( + "IP %s (IPv4-translated %s, private/loopback) for host %s", + ip, v4, originalHost, + ), true + } + } + + return "", "", false +} + +// extractIPv4Compatible extracts the embedded IPv4 from an IPv4-compatible +// IPv6 address (::x.x.x.x — first 12 bytes zero, last 4 non-zero). +func extractIPv4Compatible(ip net.IP) net.IP { + for i := 0; i < 12; i++ { + if ip[i] != 0 { + return nil + } + } + if ip[12] == 0 && ip[13] == 0 && ip[14] == 0 && ip[15] == 0 { + return nil + } + return net.IPv4(ip[12], ip[13], ip[14], ip[15]) +} + +// extractIPv4Translated extracts the embedded IPv4 from an IPv4-translated +// IPv6 address (::ffff:0:x.x.x.x — RFC 2765 §4.2.1). +func extractIPv4Translated(ip net.IP) net.IP { + for i := 0; i < 8; i++ { + if ip[i] != 0 { + return nil + } + } + if ip[8] != 0xFF || ip[9] != 0xFF || ip[10] != 0x00 || ip[11] != 0x00 { + return nil + } + if ip[12] == 0 && ip[13] == 0 && ip[14] == 0 && ip[15] == 0 { + return nil + } + return net.IPv4(ip[12], ip[13], ip[14], ip[15]) +} diff --git a/cli/azd/pkg/azdext/ssrf_guard.go b/cli/azd/pkg/azdext/ssrf_guard.go new file mode 100644 index 00000000000..6ef07fae5ff --- /dev/null +++ b/cli/azd/pkg/azdext/ssrf_guard.go @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "net" + "net/url" + "strings" + "sync" +) + +// SSRFGuard validates URLs against Server-Side Request Forgery (SSRF) attack +// patterns. It provides standalone SSRF protection for extension authors who +// need URL validation outside of MCP contexts. +// +// SSRFGuard uses a fluent builder pattern for configuration: +// +// guard := azdext.NewSSRFGuard(). +// BlockMetadataEndpoints(). +// BlockPrivateNetworks(). +// RequireHTTPS() +// +// if err := guard.Check("http://169.254.169.254/metadata"); err != nil { +// // blocked: cloud metadata endpoint +// } +// +// Use [DefaultSSRFGuard] for a preset configuration that blocks metadata +// endpoints, private networks, and requires HTTPS. +// +// SSRFGuard is safe for concurrent use from multiple goroutines. +type SSRFGuard struct { + mu sync.RWMutex + blockMetadata bool + blockPrivate bool + requireHTTPS bool + blockedCIDRs []*net.IPNet + blockedHosts map[string]bool + allowedHosts map[string]bool + // lookupHost is used for DNS resolution; override in tests. + lookupHost func(string) ([]string, error) + // onBlocked is an optional callback invoked when a URL is blocked. + // Parameters: reason (machine-readable tag), detail (human-readable). + onBlocked func(reason, detail string) +} + +// SSRFError describes why a URL was rejected by the [SSRFGuard]. +type SSRFError struct { + // URL is the rejected URL (or a sanitized representation). + URL string + + // Reason is a machine-readable tag for the violation type. + // Values: "blocked_host", "blocked_ip", "private_network", + // "metadata_endpoint", "dns_failure", "https_required", + // "invalid_url", "scheme_blocked". + Reason string + + // Detail is a human-readable explanation. + Detail string +} + +func (e *SSRFError) Error() string { + return fmt.Sprintf("azdext.SSRFGuard: %s: %s (url=%s)", e.Reason, e.Detail, e.URL) +} + +// NewSSRFGuard creates an empty SSRF guard with no active protections. +// Use the builder methods to configure protections, or use [DefaultSSRFGuard] +// for a preset secure configuration. +func NewSSRFGuard() *SSRFGuard { + return &SSRFGuard{ + blockedHosts: make(map[string]bool), + allowedHosts: make(map[string]bool), + lookupHost: net.LookupHost, + } +} + +// DefaultSSRFGuard returns a guard preconfigured with: +// - Cloud metadata endpoint blocking (AWS, Azure, GCP, Alibaba) +// - Private network blocking (RFC 1918, loopback, link-local, CGNAT, IPv6 ULA, +// 6to4, Teredo, NAT64) +// - HTTPS enforcement (except localhost) +// +// This is the recommended starting point for extension authors. +func DefaultSSRFGuard() *SSRFGuard { + return NewSSRFGuard(). + BlockMetadataEndpoints(). + BlockPrivateNetworks(). + RequireHTTPS() +} + +// BlockMetadataEndpoints blocks well-known cloud metadata service endpoints: +// - 169.254.169.254 (AWS, Azure, most cloud providers) +// - fd00:ec2::254 (AWS EC2 IPv6 metadata) +// - metadata.google.internal (GCP) +// - 100.100.100.200 (Alibaba Cloud) +func (g *SSRFGuard) BlockMetadataEndpoints() *SSRFGuard { + g.mu.Lock() + defer g.mu.Unlock() + g.blockMetadata = true + for _, host := range ssrfMetadataHosts { + g.blockedHosts[strings.ToLower(host)] = true + } + return g +} + +// BlockPrivateNetworks blocks RFC 1918 private networks, loopback, link-local, +// CGNAT (RFC 6598), and IPv6 transition mechanisms that can embed private IPv4 +// addresses (6to4, Teredo, NAT64, IPv4-compatible, IPv4-translated). +func (g *SSRFGuard) BlockPrivateNetworks() *SSRFGuard { + g.mu.Lock() + defer g.mu.Unlock() + g.blockPrivate = true + for _, cidr := range ssrfBlockedCIDRs { + _, ipNet, err := net.ParseCIDR(cidr) + if err == nil { + g.blockedCIDRs = append(g.blockedCIDRs, ipNet) + } + } + return g +} + +// RequireHTTPS requires HTTPS for all URLs except localhost and loopback +// addresses. HTTP to localhost/127.0.0.1/[::1] is always permitted for +// local development. +func (g *SSRFGuard) RequireHTTPS() *SSRFGuard { + g.mu.Lock() + defer g.mu.Unlock() + g.requireHTTPS = true + return g +} + +// AllowHost adds hosts to an explicit allowlist. Allowed hosts bypass all +// IP-based and metadata checks. Host names are compared case-insensitively. +// +// Use this sparingly — over-broad allowlists weaken SSRF protection. Prefer +// allowing specific, known-good endpoints rather than wildcards. +func (g *SSRFGuard) AllowHost(hosts ...string) *SSRFGuard { + g.mu.Lock() + defer g.mu.Unlock() + for _, h := range hosts { + g.allowedHosts[strings.ToLower(h)] = true + } + return g +} + +// OnBlocked registers a callback invoked whenever a URL is blocked. This +// enables security audit logging without coupling the guard to a logging +// framework. The callback receives the machine-readable reason tag and a +// human-readable detail string. It must be safe for concurrent invocation. +func (g *SSRFGuard) OnBlocked(fn func(reason, detail string)) *SSRFGuard { + g.mu.Lock() + defer g.mu.Unlock() + g.onBlocked = fn + return g +} + +// Check validates a URL against the guard's SSRF policy. +// +// Validation order: +// 1. Parse the URL and reject non-HTTP(S) schemes. +// 2. If HTTPS is required, reject plain HTTP to non-localhost hosts. +// 3. Skip further checks if the host is explicitly allowed via [AllowHost]. +// 4. Skip further checks for localhost/loopback hosts (local development). +// 5. Reject hosts matching the metadata endpoint blocklist. +// 6. For IP-literal hosts, check directly against blocked CIDRs. +// 7. For hostname hosts, resolve DNS (fail-closed on lookup failure) and +// check all resolved IPs against blocked CIDRs. +// +// For IPv6 addresses, embedded IPv4 (IPv4-compatible, IPv4-mapped, +// IPv4-translated per RFC 2765) is extracted and re-checked against blocked CIDRs. +// +// Returns nil if the URL is allowed, or a [*SSRFError] describing the violation. +func (g *SSRFGuard) Check(rawURL string) error { + g.mu.RLock() + fn := g.onBlocked + ssrfErr := g.checkCore(rawURL) + g.mu.RUnlock() + + if ssrfErr != nil { + if fn != nil { + fn(ssrfErr.Reason, ssrfErr.Detail) + } + return ssrfErr + } + + return nil +} + +// checkCore performs URL validation without acquiring the lock or invoking +// the onBlocked callback. Callers must hold g.mu (at least RLock). +func (g *SSRFGuard) checkCore(rawURL string) *SSRFError { + u, err := url.Parse(rawURL) + if err != nil { + return g.blocked(truncateValue(rawURL, 200), "invalid_url", "URL parsing failed: "+err.Error()) + } + + host := u.Hostname() + + // Step 1: Scheme validation — only http and https permitted. + switch u.Scheme { + case "https": + // Always allowed. + case "http": + if g.requireHTTPS && !isLocalhostHost(host) { + return g.blocked(truncateValue(rawURL, 200), "https_required", "HTTPS is required for non-localhost URLs") + } + default: + return g.blocked(truncateValue(rawURL, 200), "scheme_blocked", + fmt.Sprintf("scheme %q is not allowed (only http and https are permitted)", u.Scheme)) + } + + lowerHost := strings.ToLower(host) + + // Step 2: Explicit allowlist bypass. + if g.allowedHosts[lowerHost] { + return nil + } + + // Step 3: Localhost/loopback bypass — localhost is the developer's own + // machine and is exempt from IP-level SSRF blocking to allow local + // development workflows (e.g. local API servers, proxies, dev tools). + if isLocalhostHost(host) { + return nil + } + + // Step 5: Metadata endpoint check. + if g.blockedHosts[lowerHost] { + return g.blocked(truncateValue(rawURL, 200), "blocked_host", + fmt.Sprintf("host %s is blocked", host)) + } + + // Step 6: IP-based checks. + if ip := net.ParseIP(host); ip != nil { + // Check normalized IP form against blocked hosts — catches IPv4-mapped + // IPv6 forms like ::ffff:169.254.169.254 that bypass string matching. + if normalizedIP := ip.String(); normalizedIP != host { + if g.blockedHosts[strings.ToLower(normalizedIP)] { + return g.blocked(truncateValue(rawURL, 200), "blocked_host", + fmt.Sprintf("host %s is blocked (normalized: %s)", host, normalizedIP)) + } + } + // Direct IP literal — check against blocked ranges. + return g.checkIPForSSRF(ip, host, rawURL) + } + + // Step 7: DNS resolution for hostnames (fail-closed). + addrs, err := g.lookupHost(host) + if err != nil { + return g.blocked(truncateValue(rawURL, 200), "dns_failure", + fmt.Sprintf("DNS resolution failed for %s (fail-closed): %s", host, err.Error())) + } + + for _, addr := range addrs { + if g.blockedHosts[strings.ToLower(addr)] { + return g.blocked(truncateValue(rawURL, 200), "blocked_host", + fmt.Sprintf("host %s resolved to blocked address %s", host, addr)) + } + if ip := net.ParseIP(addr); ip != nil { + if ssrfErr := g.checkIPForSSRF(ip, host, rawURL); ssrfErr != nil { + return ssrfErr + } + } + } + + return nil +} + +// blocked creates an SSRFError. +func (g *SSRFGuard) blocked(urlStr, reason, detail string) *SSRFError { + return &SSRFError{URL: urlStr, Reason: reason, Detail: detail} +} + +// checkIPForSSRF validates an IP address against blocked CIDRs and private +// network categories. It also extracts embedded IPv4 from IPv6 encoding +// variants (IPv4-compatible, IPv4-translated RFC 2765) that Go's net.IP +// methods do not classify. +func (g *SSRFGuard) checkIPForSSRF(ip net.IP, originalHost, rawURL string) *SSRFError { + if reason, detail, isBlocked := ssrfCheckIP(ip, originalHost, g.blockedCIDRs, g.blockPrivate); isBlocked { + return g.blocked(truncateValue(rawURL, 200), reason, detail) + } + return nil +} diff --git a/cli/azd/pkg/azdext/ssrf_guard_test.go b/cli/azd/pkg/azdext/ssrf_guard_test.go new file mode 100644 index 00000000000..c858b651731 --- /dev/null +++ b/cli/azd/pkg/azdext/ssrf_guard_test.go @@ -0,0 +1,617 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "errors" + "fmt" + "net" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// SSRFGuard — metadata endpoint blocking +// --------------------------------------------------------------------------- + +func TestSSRFGuard_BlocksMetadataEndpoints(t *testing.T) { + guard := NewSSRFGuard().BlockMetadataEndpoints() + + blocked := []string{ + "http://169.254.169.254/latest/meta-data/", + "http://fd00:ec2::254/latest/meta-data/", + "http://metadata.google.internal/computeMetadata/v1/", + "http://100.100.100.200/latest/meta-data/", + // Case variations + "http://METADATA.GOOGLE.INTERNAL/computeMetadata/v1/", + // IPv4-mapped forms of metadata IPs — must be caught by IP normalization. + "http://[::ffff:169.254.169.254]/latest/meta-data/", + "http://[::ffff:100.100.100.200]/latest/meta-data/", + } + for _, u := range blocked { + if err := guard.Check(u); err == nil { + t.Errorf("Check(%s) = nil, want blocked (metadata)", u) + } + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — private network blocking +// --------------------------------------------------------------------------- + +func TestSSRFGuard_BlocksPrivateIPs(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks() + + blocked := []struct { + url string + desc string + }{ + {"http://10.0.0.1/api", "RFC 1918 class A"}, + {"http://172.16.0.1/api", "RFC 1918 class B"}, + {"http://192.168.1.1/api", "RFC 1918 class C"}, + {"http://0.0.0.1/api", "0.0.0.0/8 'this' network"}, + {"http://100.64.0.1/api", "RFC 6598 CGNAT"}, + {"http://[fe80::1]/api", "IPv6 link-local"}, + {"http://[fd00::1]/api", "IPv6 unique local (fc00::/7)"}, + {"http://[fd12:3456:789a::1]/api", "IPv6 ULA in fd00::/8"}, + {"http://[::ffff:10.0.0.1]/api", "IPv4-mapped RFC 1918"}, + {"http://[::10.0.0.1]/api", "IPv4-compatible RFC 1918"}, + {"http://[2002:a00:1::]/api", "6to4 embedding 10.0.0.1"}, + {"http://[2001:0000::1]/api", "Teredo range"}, + {"http://[64:ff9b::a00:1]/api", "NAT64 well-known prefix"}, + {"http://[64:ff9b:1::a00:1]/api", "NAT64 local-use prefix"}, + {"http://[::ffff:0:a00:1]/api", "IPv4-translated RFC 1918 (RFC 2765)"}, + } + for _, tc := range blocked { + if err := guard.Check(tc.url); err == nil { + t.Errorf("Check(%s) = nil, want blocked (%s)", tc.url, tc.desc) + } + } +} + +func TestSSRFGuard_LocalhostExempt(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks() + + // Localhost/loopback addresses are exempt from private network blocking + // to support local development workflows (API servers, proxies, etc.). + exempt := []struct { + url string + desc string + }{ + {"http://127.0.0.1/api", "IPv4 loopback"}, + {"http://localhost:8080/api", "localhost hostname"}, + {"http://[::1]:8080/api", "IPv6 loopback"}, + {"http://[::ffff:127.0.0.1]/api", "IPv4-mapped loopback"}, + } + for _, tc := range exempt { + if err := guard.Check(tc.url); err != nil { + t.Errorf("Check(%s) = %v, want nil (localhost exempt: %s)", tc.url, err, tc.desc) + } + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — allows public URLs +// --------------------------------------------------------------------------- + +func TestSSRFGuard_AllowsPublicURLs(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks().BlockMetadataEndpoints() + // Mock DNS to avoid real network calls in tests. + guard.lookupHost = func(host string) ([]string, error) { + switch host { + case "api.github.com": + return []string{"140.82.121.6"}, nil + case "example.com": + return []string{"93.184.216.34"}, nil + default: + return nil, fmt.Errorf("unknown host: %s", host) + } + } + + allowed := []string{ + "https://api.github.com/repos", + "https://example.com/data", + "https://8.8.8.8/dns", + "https://[2607:f8b0:4004:800::200e]/data", // public IPv6 + } + for _, u := range allowed { + if err := guard.Check(u); err != nil { + t.Errorf("Check(%s) = %v, want nil", u, err) + } + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — HTTPS enforcement +// --------------------------------------------------------------------------- + +func TestSSRFGuard_EnforcesHTTPS(t *testing.T) { + guard := NewSSRFGuard().RequireHTTPS() + + // HTTP to external host should be blocked. + if err := guard.Check("http://example.com/api"); err == nil { + t.Error("Check(http://example.com/api) = nil, want HTTPS required error") + } + + // HTTP to localhost should be allowed. + if err := guard.Check("http://localhost:8080/api"); err != nil { + t.Errorf("Check(http://localhost:8080/api) = %v, want nil (localhost exempt)", err) + } + + // HTTP to 127.0.0.1 should be allowed. + if err := guard.Check("http://127.0.0.1:8080/api"); err != nil { + t.Errorf("Check(http://127.0.0.1:8080/api) = %v, want nil (loopback exempt)", err) + } + + // HTTP to [::1] should be allowed. + if err := guard.Check("http://[::1]:8080/api"); err != nil { + t.Errorf("Check(http://[::1]:8080/api) = %v, want nil (IPv6 loopback exempt)", err) + } + + // HTTPS should always be allowed. + if err := guard.Check("https://example.com/api"); err != nil { + t.Errorf("Check(https://example.com/api) = %v, want nil", err) + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — scheme blocking +// --------------------------------------------------------------------------- + +func TestSSRFGuard_BlocksExoticSchemes(t *testing.T) { + guard := NewSSRFGuard() // Even empty guard blocks non-HTTP schemes. + + blocked := []struct { + url string + desc string + }{ + {"ftp://example.com/file", "ftp"}, + {"gopher://example.com/path", "gopher"}, + {"file:///etc/passwd", "file"}, + {"//evil.com/path", "protocol-relative (empty scheme)"}, + {"ws://example.com/socket", "websocket"}, + {"wss://example.com/socket", "secure websocket"}, + {"ssh://example.com", "ssh"}, + {"telnet://example.com", "telnet"}, + {"ldap://example.com", "ldap"}, + {"dict://example.com", "dict"}, + {"jar://example.com", "jar"}, + } + for _, tc := range blocked { + err := guard.Check(tc.url) + if err == nil { + t.Errorf("Check(%s) = nil, want blocked (%s scheme)", tc.url, tc.desc) + continue + } + var ssrfErr *SSRFError + if !errors.As(err, &ssrfErr) { + t.Errorf("Check(%s) returned %T, want *SSRFError", tc.url, err) + continue + } + if ssrfErr.Reason != "scheme_blocked" { + t.Errorf("Check(%s).Reason = %q, want %q", tc.url, ssrfErr.Reason, "scheme_blocked") + } + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — DNS resolution +// --------------------------------------------------------------------------- + +func TestSSRFGuard_DNSResolvesToBlockedIP(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks() + guard.lookupHost = func(host string) ([]string, error) { + return []string{"10.0.0.1"}, nil + } + + err := guard.Check("http://evil.example.com/steal") + if err == nil { + t.Error("Check should block URL resolving to private IP via DNS") + } + + var ssrfErr *SSRFError + if !errors.As(err, &ssrfErr) { + t.Fatalf("Check returned %T, want *SSRFError", err) + } + if ssrfErr.Reason != "blocked_ip" { + t.Errorf("Reason = %q, want %q", ssrfErr.Reason, "blocked_ip") + } +} + +func TestSSRFGuard_DNSResolvesToBlockedHost(t *testing.T) { + guard := NewSSRFGuard().BlockMetadataEndpoints() + guard.lookupHost = func(host string) ([]string, error) { + return []string{"169.254.169.254"}, nil + } + + err := guard.Check("http://evil.example.com/steal") + if err == nil { + t.Error("Check should block URL resolving to metadata IP via DNS") + } +} + +func TestSSRFGuard_DNSFailureBlocksRequest(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks() + guard.lookupHost = func(host string) ([]string, error) { + return nil, fmt.Errorf("dns: NXDOMAIN") + } + + err := guard.Check("http://evil.example.com/steal") + if err == nil { + t.Fatal("Check should block URL when DNS resolution fails (fail-closed)") + } + + var ssrfErr *SSRFError + if !errors.As(err, &ssrfErr) { + t.Fatalf("Check returned %T, want *SSRFError", err) + } + if ssrfErr.Reason != "dns_failure" { + t.Errorf("Reason = %q, want %q", ssrfErr.Reason, "dns_failure") + } + if !strings.Contains(ssrfErr.Detail, "fail-closed") { + t.Errorf("Detail should mention fail-closed, got: %s", ssrfErr.Detail) + } +} + +func TestSSRFGuard_DNSMultipleAddresses(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks() + guard.lookupHost = func(host string) ([]string, error) { + // First address is public, second is private — should still block. + return []string{"8.8.8.8", "192.168.1.1"}, nil + } + + err := guard.Check("http://dual-homed.example.com/api") + if err == nil { + t.Error("Check should block when any resolved IP is private") + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — allowlist +// --------------------------------------------------------------------------- + +func TestSSRFGuard_AllowHost(t *testing.T) { + guard := NewSSRFGuard(). + BlockPrivateNetworks(). + BlockMetadataEndpoints(). + AllowHost("internal.corp.example.com") + + // The allowed host should bypass all checks. + guard.lookupHost = func(host string) ([]string, error) { + if host == "internal.corp.example.com" { + return []string{"10.0.0.50"}, nil // would normally be blocked + } + return nil, fmt.Errorf("unknown host") + } + + if err := guard.Check("http://internal.corp.example.com/api"); err != nil { + t.Errorf("Check allowed host = %v, want nil", err) + } + + // Non-allowed hosts should still be blocked. + guard.lookupHost = func(host string) ([]string, error) { + return []string{"10.0.0.50"}, nil + } + if err := guard.Check("http://not-allowed.example.com/api"); err == nil { + t.Error("Check non-allowed host resolving to private IP = nil, want error") + } +} + +func TestSSRFGuard_AllowHostCaseInsensitive(t *testing.T) { + guard := NewSSRFGuard(). + BlockMetadataEndpoints(). + AllowHost("Allowed.Example.COM") + + guard.lookupHost = func(host string) ([]string, error) { + return []string{"1.2.3.4"}, nil + } + + if err := guard.Check("http://allowed.example.com/api"); err != nil { + t.Errorf("AllowHost should be case-insensitive, got: %v", err) + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — DefaultSSRFGuard preset +// --------------------------------------------------------------------------- + +func TestDefaultSSRFGuard(t *testing.T) { + guard := DefaultSSRFGuard() + + // Should block metadata. + if err := guard.Check("http://169.254.169.254/metadata"); err == nil { + t.Error("DefaultSSRFGuard should block metadata endpoint") + } + + // Should block private IPs. + if err := guard.Check("http://10.0.0.1/api"); err == nil { + t.Error("DefaultSSRFGuard should block private IPs") + } + + // Should require HTTPS. + if err := guard.Check("http://example.com/api"); err == nil { + t.Error("DefaultSSRFGuard should require HTTPS") + } + + // Should allow HTTPS public URLs. + if err := guard.Check("https://example.com/api"); err != nil { + t.Errorf("DefaultSSRFGuard should allow HTTPS public URL, got: %v", err) + } + + // Should allow HTTP to localhost. + if err := guard.Check("http://localhost:8080/api"); err != nil { + t.Errorf("DefaultSSRFGuard should allow HTTP to localhost, got: %v", err) + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — empty guard permissiveness +// --------------------------------------------------------------------------- + +func TestSSRFGuard_EmptyGuardAllowsHTTP(t *testing.T) { + guard := NewSSRFGuard() + + // Empty guard should allow HTTP and HTTPS but still block exotic schemes. + if err := guard.Check("http://example.com/api"); err != nil { + t.Errorf("empty guard should allow HTTP, got: %v", err) + } + if err := guard.Check("https://example.com/api"); err != nil { + t.Errorf("empty guard should allow HTTPS, got: %v", err) + } + if err := guard.Check("ftp://example.com/file"); err == nil { + t.Error("empty guard should still block FTP scheme") + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — edge cases +// --------------------------------------------------------------------------- + +func TestSSRFGuard_InvalidURL(t *testing.T) { + guard := NewSSRFGuard() + + err := guard.Check("://invalid") + if err == nil { + t.Error("Check(invalid URL) = nil, want error") + return + } + var ssrfErr *SSRFError + if !errors.As(err, &ssrfErr) { + t.Errorf("Check returned %T, want *SSRFError", err) + return + } + if ssrfErr.Reason != "scheme_blocked" && ssrfErr.Reason != "invalid_url" { + t.Errorf("Reason = %q, want scheme_blocked or invalid_url", ssrfErr.Reason) + } +} + +func TestSSRFGuard_URLTruncation(t *testing.T) { + guard := NewSSRFGuard().RequireHTTPS() + + // Create a very long URL. + longURL := "http://" + strings.Repeat("a", 300) + ".com/path" + err := guard.Check(longURL) + if err == nil { + t.Fatal("Check(long http URL) = nil, want HTTPS error") + } + + var ssrfErr *SSRFError + if !errors.As(err, &ssrfErr) { + t.Fatalf("Check returned %T, want *SSRFError", err) + } + // URL should be truncated in the error to avoid log flooding. + if len(ssrfErr.URL) > 210 { + t.Errorf("SSRFError.URL should be truncated, got length %d", len(ssrfErr.URL)) + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — concurrent safety +// --------------------------------------------------------------------------- + +func TestSSRFGuard_ConcurrentCheck(t *testing.T) { + guard := DefaultSSRFGuard() + + done := make(chan struct{}) + for range 100 { + go func() { + defer func() { done <- struct{}{} }() + _ = guard.Check("https://example.com/api") + _ = guard.Check("http://10.0.0.1/api") + }() + } + for range 100 { + <-done + } +} + +// --------------------------------------------------------------------------- +// SSRFGuard — obfuscated IP host formats +// --------------------------------------------------------------------------- + +// TestSSRFGuard_ObfuscatedIPFormats documents and verifies SSRF safety for +// various obfuscated IP address formats that attackers may use to bypass +// hostname-based blocklists. These are non-standard host representations +// that some URL parsers/HTTP clients accept and resolve to private IPs. +func TestSSRFGuard_ObfuscatedIPFormats(t *testing.T) { + guard := NewSSRFGuard().BlockPrivateNetworks().BlockMetadataEndpoints() + + // Obfuscated IPv4 formats — these are blocked because Go's url.Parse + // normalizes or rejects them, and our IP checks catch the resolved form. + tests := []struct { + url string + desc string + expect string // "block" or "block_or_error" (parser may reject) + }{ + // Decimal encoding of 127.0.0.1 (2130706433 = 0x7F000001). + // Go's net.ParseIP does NOT parse decimal IPs, so this falls through + // to DNS resolution which will fail (fail-closed → blocked). + {"http://2130706433/api", "decimal IP (127.0.0.1)", "block_or_error"}, + + // Octal encoding: 0177.0.0.1 = 127.0.0.1. + // Go's net.ParseIP rejects octal, so DNS lookup fails (fail-closed). + {"http://0177.0.0.1/api", "octal IP (127.0.0.1)", "block_or_error"}, + + // Hex encoding: 0x7f000001 = 127.0.0.1. + // Go's net.ParseIP rejects hex integers → DNS fail-closed. + {"http://0x7f000001/api", "hex IP (127.0.0.1)", "block_or_error"}, + + // IPv6 bracket notation with private IPv4-mapped address. + {"http://[::ffff:10.0.0.1]/api", "IPv4-mapped private in brackets", "block"}, + + // IPv4-compatible IPv6 embedding private IP. + {"http://[::10.0.0.1]/api", "IPv4-compatible private (::10.0.0.1)", "block"}, + + // Zero-padded IPv4 octets (non-standard; rejected by Go parser). + {"http://127.000.000.001/api", "zero-padded loopback", "block_or_error"}, + + // Mixed-case metadata hostname. + {"http://Metadata.Google.Internal/api", "mixed-case metadata host", "block"}, + + // IPv6 with zone ID — net.ParseIP rejects zone IDs, blocked on parse. + {"http://[fe80::1%25eth0]/api", "IPv6 link-local with zone", "block"}, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + // Override DNS to fail-closed (prevents real network calls). + guard.lookupHost = func(host string) ([]string, error) { + return nil, fmt.Errorf("dns: NXDOMAIN (test)") + } + + err := guard.Check(tc.url) + if err == nil { + t.Errorf("Check(%s) = nil, want blocked (%s)", tc.url, tc.desc) + } + }) + } +} + +// --------------------------------------------------------------------------- +// SSRFError +// --------------------------------------------------------------------------- + +func TestSSRFError_ErrorMessage(t *testing.T) { + err := &SSRFError{ + URL: "http://evil.com/steal", + Reason: "blocked_ip", + Detail: "IP 10.0.0.1 is private", + } + + msg := err.Error() + if !strings.Contains(msg, "blocked_ip") { + t.Errorf("error message should contain reason, got: %s", msg) + } + if !strings.Contains(msg, "evil.com") { + t.Errorf("error message should contain URL, got: %s", msg) + } + if !strings.Contains(msg, "10.0.0.1") { + t.Errorf("error message should contain detail, got: %s", msg) + } +} + +// --------------------------------------------------------------------------- +// IPv6 embedding extraction +// --------------------------------------------------------------------------- + +func TestExtractIPv4Compatible(t *testing.T) { + tests := []struct { + name string + input string + wantV4 string + }{ + {"loopback", "::127.0.0.1", "127.0.0.1"}, + {"private", "::10.0.0.1", "10.0.0.1"}, + {"public", "::8.8.8.8", "8.8.8.8"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ip := parseIPv6(t, tc.input) + v4 := extractIPv4Compatible(ip) + if v4 == nil { + t.Fatal("extractIPv4Compatible returned nil") + } + if !v4.Equal(parseIP(t, tc.wantV4)) { + t.Errorf("extractIPv4Compatible(%s) = %s, want %s", tc.input, v4, tc.wantV4) + } + }) + } +} + +func TestExtractIPv4Compatible_ReturnsNil(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"normal_ipv6", "2001:db8::1"}, + {"ipv4_mapped", "::ffff:10.0.0.1"}, // To4() != nil, so not pure IPv6 + {"all_zeros", "::"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ip := parseIPv6(t, tc.input) + // IPv4-mapped addresses have To4() != nil and won't reach extractIPv4Compatible. + // For testing the extraction function directly, skip those. + if ip.To4() != nil { + t.Skip("IPv4-mapped; To4() != nil") + } + v4 := extractIPv4Compatible(ip) + if v4 != nil { + t.Errorf("extractIPv4Compatible(%s) = %s, want nil", tc.input, v4) + } + }) + } +} + +func TestExtractIPv4Translated(t *testing.T) { + tests := []struct { + name string + input string + wantV4 string + }{ + {"loopback", "::ffff:0:127.0.0.1", "127.0.0.1"}, + {"private", "::ffff:0:10.0.0.1", "10.0.0.1"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ip := parseIPv6(t, tc.input) + if ip.To4() != nil { + t.Skip("To4() != nil; not a pure IPv6") + } + v4 := extractIPv4Translated(ip) + if v4 == nil { + t.Fatal("extractIPv4Translated returned nil") + } + if !v4.Equal(parseIP(t, tc.wantV4)) { + t.Errorf("extractIPv4Translated(%s) = %s, want %s", tc.input, v4, tc.wantV4) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func parseIP(t *testing.T, s string) net.IP { + t.Helper() + ip := net.ParseIP(s) + if ip == nil { + t.Fatalf("failed to parse IP: %s", s) + } + return ip +} + +func parseIPv6(t *testing.T, s string) net.IP { + t.Helper() + ip := net.ParseIP(s) + if ip == nil { + t.Fatalf("failed to parse IPv6: %s", s) + } + // Ensure we have a 16-byte representation. + if len(ip) != net.IPv6len { + ip = ip.To16() + } + return ip +}