diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go
new file mode 100644
index 00000000000..843729dc5d4
--- /dev/null
+++ b/cli/azd/pkg/azdext/config_helper.go
@@ -0,0 +1,457 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azdext

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"regexp"
	"strings"
)

// ConfigHelper provides typed, ergonomic access to azd configuration through
// the gRPC UserConfig and Environment services. It eliminates the boilerplate
// of raw gRPC calls and JSON marshaling that extension authors otherwise need.
//
// Configuration sources (in merge priority, lowest to highest):
// 1. User config (global azd config) — via UserConfigService
// 2. Environment config (per-env) — via EnvironmentService
//
// Usage:
//
//	ch := azdext.NewConfigHelper(client)
//	port, err := ch.GetUserString(ctx, "extensions.myext.port")
//	var cfg MyConfig
//	err = ch.GetUserJSON(ctx, "extensions.myext", &cfg)
type ConfigHelper struct {
	// client is the azd gRPC client whose UserConfig() and Environment()
	// services back every helper method on this type.
	client *AzdClient
}

// NewConfigHelper creates a [ConfigHelper] for the given AZD client.
// A nil client is rejected here with an error rather than causing a
// nil-pointer panic on first use.
func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) {
	if client == nil {
		return nil, errors.New("azdext.NewConfigHelper: client must not be nil")
	}

	return &ConfigHelper{client: client}, nil
}

// --- User Config (global) ---

// GetUserString retrieves a string value from the global user config at the
// given dot-separated path. Returns ("", false, nil) when the path does not
// exist, and ("", false, err) on gRPC errors.
+func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. +func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. 
+func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. +func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. 
+func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. 
+func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). 
+func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. 
+func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. + Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +// configSegmentRe matches a single segment: starts with alphanumeric, +// contains only [a-zA-Z0-9_-], and is 1–63 characters. 
var configSegmentRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`)

// validatePath checks that a config path is non-empty, does not contain
// consecutive dots or empty segments, and that each dot-separated segment
// conforms to a restricted character set (starts with an alphanumeric,
// then [a-zA-Z0-9_-], at most 63 characters per segment).
//
// Valid:   "extensions.myext.port", "my-key", "a.b.c"
// Invalid: "", ".leading", "trailing.", "a..b", "-dash-first", "bad!char"
func validatePath(path string) error {
	if path == "" {
		return errors.New("azdext.ConfigHelper: config path must not be empty")
	}

	// Reject leading/trailing dots and consecutive dots (empty segments) up
	// front, so the caller gets a message about path shape rather than the
	// generic per-segment error below.
	if strings.HasPrefix(path, ".") || strings.HasSuffix(path, ".") || strings.Contains(path, "..") {
		return errors.New(
			"azdext.ConfigHelper: config path must not have empty segments " +
				"(no leading/trailing dots or consecutive dots)",
		)
	}

	// Validate each dot-separated segment individually. The segment is
	// truncated in the error message so an adversarial path cannot bloat logs.
	for _, seg := range strings.Split(path, ".") {
		if !configSegmentRe.MatchString(seg) {
			return fmt.Errorf(
				"azdext.ConfigHelper: config path segment %q must start with alphanumeric "+
					"and contain only [a-zA-Z0-9_-], max 63 chars",
				truncateConfigValue(seg, 64),
			)
		}
	}

	return nil
}

// truncateConfigValue truncates s to at most maxLen bytes for safe inclusion
// in error messages, appending "..." when truncation occurs.
func truncateConfigValue(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen] + "..."
}

// NOTE(review): the remainder of this line is the diff header for the test
// file, preserved verbatim below as comment text because the patch's
// newlines were lost:
// diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go
// new file mode 100644 index 00000000000..ba5c80e6321 --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,977 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+ +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) 
(*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := 
NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig 
struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = 
%T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := 
ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + 
+func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != 
nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", 
result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] != false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func 
TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator 
:= func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + 
{ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} + +func TestValidatePath_InvalidCharacters(t *testing.T) { + t.Parallel() + + for _, path := range []string{"extensions/myext", "extensions myext", ".badstart"} { + if err := validatePath(path); err == nil { + t.Errorf("validatePath(%q) expected error, got nil", path) + } + } +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..70bf4814ed5 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,322 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Secret references use the akvs:// URI scheme: +// +// akvs://<subscription-id>/<vault-name>/<secret-name> +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions +} + +// secretClientFactory abstracts secret client creation for testability.
+type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. +type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider] created via +// [NewTokenProvider]. If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client.
+func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for an akvs:// reference. +// +// The reference must match the format: akvs://<subscription-id>/<vault-name>/<secret-name> +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. +func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) { + if ctx == nil { + return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil") + } + + parsed, err := ParseSecretReference(ref) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonInvalidReference, + Err: err, + } + } + + vaultURL := fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix) + + client, err := r.clientFactory(vaultURL, r.credential) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonClientCreation, + Err: fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err), + } + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, "", nil) + if err != nil { + // Default to ServiceError for non-HTTP errors (network timeout, DNS + // failure, etc.). AccessDenied is only used for 401/403 status codes.
+ reason := ResolveReasonServiceError + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + switch respErr.StatusCode { + case http.StatusNotFound: + reason = ResolveReasonNotFound + case http.StatusForbidden, http.StatusUnauthorized: + reason = ResolveReasonAccessDenied + default: + reason = ResolveReasonServiceError + } + } + + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: reason, + Err: fmt.Errorf( + "failed to retrieve secret %q from vault %q: %w", + parsed.SecretName, + parsed.VaultName, + err, + ), + } + } + + if resp.Value == nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonNotFound, + Err: fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName), + } + } + + return *resp.Value, nil +} + +// ResolveMap resolves a map of key → akvs:// references, returning a map of +// key → resolved secret values. Processing stops at the first error. +// +// Non-akvs:// values are passed through unchanged, so callers can safely +// resolve a mixed map of plain values and secret references. +func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) { + if ctx == nil { + return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil") + } + + result := make(map[string]string, len(refs)) + + for key, value := range refs { + if !IsSecretReference(value) { + result[key] = value + continue + } + + resolved, err := r.Resolve(ctx, value) + if err != nil { + return nil, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: key %q: %w", key, err) + } + + result[key] = resolved + } + + return result, nil +} + +// SecretReference represents a parsed akvs:// URI. +type SecretReference struct { + // SubscriptionID is the Azure subscription containing the Key Vault. + SubscriptionID string + + // VaultName is the Key Vault name (not the full URL). 
+ VaultName string + + // SecretName is the name of the secret within the vault. + SecretName string +} + +// IsSecretReference reports whether s uses the akvs:// scheme. +func IsSecretReference(s string) bool { + return keyvault.IsAzureKeyVaultSecret(s) +} + +// vaultNameRe validates Azure Key Vault names per Azure naming rules: +// - 3–24 characters +// - starts with a letter +// - contains only alphanumeric and hyphens +// - does not end with a hyphen +var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`) + +// ParseSecretReference parses an akvs:// URI into its components. +// +// Expected format: akvs://<subscription-id>/<vault-name>/<secret-name> +// +// The vault name is validated against Azure Key Vault naming rules (3–24 +// characters, starts with letter, alphanumeric and hyphens only, does not +// end with a hyphen). +func ParseSecretReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err + } + + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, and contain only alphanumeric characters and non-consecutive hyphens", + ref, parsed.VaultName, + ) + } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) + } + + return &SecretReference{ + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError].
+type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the akvs:// URI is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. +func (r ResolveReason) String() string { + switch r { + case ResolveReasonInvalidReference: + return "invalid_reference" + case ResolveReasonClientCreation: + return "client_creation" + case ResolveReasonNotFound: + return "not_found" + case ResolveReasonAccessDenied: + return "access_denied" + case ResolveReasonServiceError: + return "service_error" + default: + return "unknown" + } +} + +// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails. +type KeyVaultResolveError struct { + // Reference is the original akvs:// URI that was being resolved. + Reference string + + // Reason classifies the failure. + Reason ResolveReason + + // Err is the underlying error. + Err error +} + +func (e *KeyVaultResolveError) Error() string { + return fmt.Sprintf( + "azdext.KeyVaultResolver: %s (ref=%s): %v", + e.Reason, e.Reference, e.Err, + ) +} + +func (e *KeyVaultResolveError) Unwrap() error { + return e.Err +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go new file mode 100644 index 00000000000..ade5d0cf3f0 --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -0,0 +1,576 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// stubSecretGetter is a test double for the Key Vault data-plane client. +type stubSecretGetter struct { + resp azsecrets.GetSecretResponse + err error +} + +func (s *stubSecretGetter) GetSecret( + _ context.Context, _ string, _ string, _ *azsecrets.GetSecretOptions, +) (azsecrets.GetSecretResponse, error) { + return s.resp, s.err +} + +// stubSecretFactory returns a factory that always returns the given stubSecretGetter. +func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + 
t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace 
secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(context.Background(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(context.Background(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func 
TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_SecretNotFound(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/missing-secret") + if err == nil { + t.Fatal("expected error for missing secret") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_AccessDenied(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusForbidden}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for forbidden access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + 
t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_Unauthorized(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusUnauthorized}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for unauthorized access") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonAccessDenied { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonAccessDenied) + } +} + +func TestResolve_ServiceError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for server error") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, 
&KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(context.Background(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to service_error (not access_denied), + // since network/DNS/timeout errors are not auth issues. 
+ if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(context.Background(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(context.Background(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestResolveMap_ErrorStopsProcessing(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ + "secret": "akvs://sub/vault/missing", + } + + _, err := resolver.ResolveMap(context.Background(), input) + if err == nil { + t.Fatal("expected error when 
resolution fails") + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + {ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + {ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..27c0d9a6860 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" 
"path/filepath" @@ -25,6 +26,10 @@ type MCPSecurityPolicy struct { blockedHosts map[string]bool // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) + // onBlocked is an optional callback invoked when a URL or path is blocked. + // Parameters: action ("url_blocked", "path_blocked", "redirect_blocked"), + // detail (human-readable explanation). Safe for concurrent use. + onBlocked func(action, detail string) } // NewMCPSecurityPolicy creates an empty security policy. @@ -111,6 +116,20 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } +// OnBlocked registers a callback that is invoked whenever a URL, path, or +// redirect is blocked by the security policy. This enables security audit +// logging without coupling the policy to a specific logging framework. +// +// The callback receives an action tag ("url_blocked", "path_blocked", +// "redirect_blocked") and a human-readable detail string. It must be safe +// for concurrent invocation. +func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecurityPolicy { + p.mu.Lock() + defer p.mu.Unlock() + p.onBlocked = fn + return p +} + // isLocalhostHost returns true if the host is localhost or a loopback address. func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -125,8 +144,16 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. 
func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - defer p.mu.RUnlock() + onBlocked := p.onBlocked + err := p.checkURLCore(rawURL) + p.mu.RUnlock() + if err != nil && onBlocked != nil { + onBlocked("url_blocked", err.Error()) + } + return err +} +func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) @@ -251,10 +278,33 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. +// +// Security note (TOCTOU): There is an inherent time-of-check to time-of-use +// gap between the symlink resolution performed here and the caller's +// subsequent file operation. An adversary with write access to the filesystem +// could create or modify a symlink between the check and the use. This is a +// fundamental limitation of path-based validation on POSIX systems. +// +// Mitigations callers should consider: +// - Use O_NOFOLLOW when opening files after validation (prevents symlink +// following at the final component). +// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on +// Linux 5.6+) where possible. +// - Avoid writing to directories that untrusted users can modify. +// - Consider validating the opened fd's path post-open via /proc/self/fd/N +// or fstat. 
func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - defer p.mu.RUnlock() + onBlocked := p.onBlocked + err := p.checkPathCore(path) + p.mu.RUnlock() + if err != nil && onBlocked != nil { + onBlocked("path_blocked", err.Error()) + } + return err +} +func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil } @@ -348,3 +398,47 @@ func resolveExistingPrefix(p string) string { } } } + +// --------------------------------------------------------------------------- +// Redirect SSRF protection +// --------------------------------------------------------------------------- + +// redirectBlockedHosts lists cloud metadata service endpoints that must never +// be the target of an HTTP redirect. +var redirectBlockedHosts = map[string]bool{ + "169.254.169.254": true, + "fd00:ec2::254": true, + "metadata.google.internal": true, + "100.100.100.200": true, +} + +// SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks +// redirects to private networks and cloud metadata endpoints. It prevents +// redirect-based SSRF attacks where an attacker-controlled URL redirects to +// an internal service. +// +// Usage: +// +// client := &http.Client{CheckRedirect: azdext.SSRFSafeRedirect} +func SSRFSafeRedirect(req *http.Request, via []*http.Request) error { + const maxRedirects = 10 + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + + host := req.URL.Hostname() + + // Block redirects to known metadata endpoints. + if redirectBlockedHosts[strings.ToLower(host)] { + return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host) + } + + // Block redirects to private/loopback IP addresses. 
+ if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection)", ip) + } + } + + return nil +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 6f95645686d..021a7f1d5cd 100644 --- a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -14,10 +14,19 @@ import ( "strings" ) +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + const ( // maxPageResponseSize limits the maximum size of a single page response // body to prevent excessive memory consumption from malicious or - // misconfigured servers. + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. maxPageResponseSize int64 = 10 << 20 // 10 MB // maxErrorBodySize limits the size of error response bodies captured @@ -44,6 +53,8 @@ type Pager[T any] struct { done bool opts PagerOptions originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far + truncated bool } // PageResponse is a single page returned by [Pager.NextPage]. @@ -59,6 +70,18 @@ type PageResponse[T any] struct { type PagerOptions struct { // Method overrides the HTTP method used for page requests. Defaults to GET. Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. 
+ MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). + MaxItems int } // HTTPDoer abstracts the HTTP call so that [ResilientClient] or any @@ -117,6 +140,11 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } +// Truncated reports whether [Collect] stopped due to MaxPages or MaxItems limits. +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + // NextPage fetches the next page of results. Returns an error if the request // fails, the response is not 2xx, or the body cannot be decoded. // @@ -145,7 +173,7 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: string(body), + Body: sanitizeErrorBody(string(body)), } } @@ -170,6 +198,9 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { p.nextURL = page.NextLink } + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + return &page, nil } @@ -199,12 +230,23 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { } // Collect is a convenience method that fetches all remaining pages and -// returns all items in a single slice. Use with caution on large result sets. +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. // // If NextPage returns both page data and an error (e.g. rejected nextLink), // the page data is included in the returned slice before returning the error. 
func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { var all []T + p.truncated = false + + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } for p.More() { page, err := p.NextPage(ctx) @@ -214,21 +256,82 @@ if err != nil { return all, err } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + p.truncated = true + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. + if p.pageCount >= maxPages { + p.truncated = true + break + } } return all, nil } +// maxPaginationErrorBodyLen limits the response body length stored in +// PaginationError.Body to prevent sensitive data leakage and log bloat +// through error messages. +const maxPaginationErrorBodyLen = 1024 + // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int URL string - Body string + // Body is a truncated, sanitized excerpt of the error response body for + // diagnostics. Capped at [maxPaginationErrorBodyLen] bytes and stripped + // of control characters to prevent log forging. + Body string } func (e *PaginationError) Error() string { return fmt.Sprintf( "azdext.Pager: page request returned HTTP %d (url=%s)", - e.StatusCode, e.URL, + e.StatusCode, redactURL(e.URL), ) } + +// sanitizeErrorBody truncates and strips control characters from an error +// response body to prevent log forging and sensitive data leakage. 
+func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +// stripControlChars replaces ASCII control characters (except tab) with a +// space to prevent log forging via CR/LF injection or terminal escape sequences. +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. +func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 1f8c3537279..58e80ef9f4c 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -481,3 +482,106 @@ func TestPager_CollectWithSSRFError(t *testing.T) { t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) } } + +func TestPager_TruncatedByMaxPages(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := 
pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxPages)") + } +} + +func TestPager_TruncatedByMaxItems(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxItems)") + } +} + +func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 2 { + t.Errorf("len(all) = %d, want 2", len(all)) + } + + if pager.Truncated() { + t.Error("Truncated() = true, want false (natural end)") + } +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index 2916b885334..75d2d329596 100644 --- 
a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math" + "math/rand/v2" "net/http" "strconv" "time" @@ -107,8 +108,9 @@ func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientCli return &ResilientClient{ httpClient: &http.Client{ - Transport: transport, - Timeout: opts.Timeout, + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, }, tokenProvider: tokenProvider, scopeDetector: sd, @@ -126,6 +128,13 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R if ctx == nil { return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } + if body != nil && rc.opts.MaxRetries > 0 { + if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } var lastErr error var retryAfterOverride time.Duration @@ -240,7 +249,8 @@ func (rc *ResilientClient) backoff(attempt int) time.Duration { delay = rc.opts.MaxDelay } - return delay + jitter := 0.5 + rand.Float64()*0.5 + return time.Duration(float64(delay) * jitter) } // isRetryable returns true for status codes that indicate a transient failure. diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 4ffd67af103..3e5cddc9896 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -530,9 +530,9 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made exactly 1 attempt (first gets 503 → retry → fail on body check). - if attempts != 1 { - t.Errorf("attempts = %d, want 1 (fail before second attempt)", attempts) + // Should fail fast before any request. 
+ if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) } } @@ -620,3 +620,92 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) } } + +func TestResilientClient_BackoffJitter(t *testing.T) { + t.Parallel() + + rc := NewResilientClient(nil, &ResilientClientOptions{ + InitialDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + }) + + seen := make(map[time.Duration]bool) + for range 20 { + d := rc.backoff(1) + seen[d] = true + if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + h := http.Header{} + h.Set("retry-after", 
"999999") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 1, + InitialDelay: time.Millisecond, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded (proving cap was applied), got: %v", err) + } + + if attempts != 1 { + t.Errorf("attempts = %d, want 1", attempts) + } +}