diff --git a/client/assets/settings.go b/client/assets/settings.go index cb9fc9b2..82df01e2 100644 --- a/client/assets/settings.go +++ b/client/assets/settings.go @@ -1,11 +1,8 @@ package assets import ( - "encoding/json" "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/malice-network/helper/utils/configutil" - "io/ioutil" - "path/filepath" + "github.com/gookit/config/v2" ) //var ( @@ -20,10 +17,24 @@ type Settings struct { LocalRPCEnable bool `yaml:"localrpc_enable" config:"localrpc_enable" default:"false"` LocalRPCAddr string `yaml:"localrpc_addr" config:"localrpc_addr" default:"127.0.0.1:15004"` Github *GithubSetting `yaml:"github" config:"github"` + AI *AISettings `yaml:"ai" config:"ai"` //VtApiKey string `yaml:"vt_api_key" config:"vt_api_key" default:""` } +// AISettings holds configuration for AI assistant integration +type AISettings struct { + Enable bool `yaml:"enable" config:"enable" default:"false"` + Provider string `yaml:"provider" config:"provider" default:"openai"` // openai, claude + APIKey string `yaml:"api_key" config:"api_key" default:""` + Endpoint string `yaml:"endpoint" config:"endpoint" default:"https://api.openai.com/v1"` + Model string `yaml:"model" config:"model" default:"gpt-4"` + MaxTokens int `yaml:"max_tokens" config:"max_tokens" default:"1024"` + Timeout int `yaml:"timeout" config:"timeout" default:"30"` + HistorySize int `yaml:"history_size" config:"history_size" default:"20"` + OpsecCheck bool `yaml:"opsec_check" config:"opsec_check" default:"false"` // Enable AI OPSEC risk assessment +} + type GithubSetting struct { Repo string `yaml:"repo" config:"repo" default:""` Owner string `yaml:"owner" config:"owner" default:""` @@ -44,17 +55,24 @@ func (github *GithubSetting) ToProtobuf() *clientpb.GithubActionBuildConfig { } func LoadSettings() (*Settings, error) { - rootDir, _ := filepath.Abs(GetRootAppDir()) - //data, err := os.ReadFile(filepath.Join(rootDir, settingsFileName)) - //if err != nil { - // return defaultSettings(), err - //} - settings := defaultSettings() - err := configutil.LoadConfig(filepath.Join(rootDir, maliceProfile), settings) + setting, err := GetSetting() + if err == nil && setting != nil { + return setting, nil + } + + _, loadErr := LoadProfile() + if loadErr != nil { + return defaultSettings(), loadErr + } + + setting, err = GetSetting() if err != nil { return defaultSettings(), err } - return settings, nil + if setting == nil { + return defaultSettings(), nil + } + return setting, nil } func defaultSettings() *Settings { @@ -68,16 +86,67 @@ func defaultSettings() *Settings { } } +// setConfigs sets multiple config key-value pairs, returning the first error encountered. +func setConfigs(kvs [][2]interface{}) error { + for _, kv := range kvs { + if err := config.Set(kv[0].(string), kv[1]); err != nil { + return err + } + } + return nil +} + // SaveSettings - Save the current settings to disk func SaveSettings(settings *Settings) error { - rootDir, _ := filepath.Abs(GetRootAppDir()) if settings == nil { settings = defaultSettings() } - data, err := json.MarshalIndent(settings, "", " ") - if err != nil { + + // Ensure profile is loaded so we don't overwrite unrelated config sections. + if _, err := LoadProfile(); err != nil { + return err + } + + // Top-level settings + if err := setConfigs([][2]interface{}{ + {"settings.max_server_log_size", settings.MaxServerLogSize}, + {"settings.opsec_threshold", settings.OpsecThreshold}, + {"settings.mcp_enable", settings.McpEnable}, + {"settings.mcp_addr", settings.McpAddr}, + {"settings.localrpc_enable", settings.LocalRPCEnable}, + {"settings.localrpc_addr", settings.LocalRPCAddr}, + }); err != nil { return err } - err = ioutil.WriteFile(filepath.Join(rootDir, maliceProfile), data, 0600) - return err + + // Github settings + if settings.Github != nil { + if err := setConfigs([][2]interface{}{ + {"settings.github.repo", settings.Github.Repo}, + {"settings.github.owner", settings.Github.Owner}, + {"settings.github.token", settings.Github.Token}, + {"settings.github.workflow", settings.Github.Workflow}, + }); err != nil { + return err + } + } + + // AI settings + if settings.AI != nil { + if err := setConfigs([][2]interface{}{ + {"settings.ai.enable", settings.AI.Enable}, + {"settings.ai.provider", settings.AI.Provider}, + {"settings.ai.api_key", settings.AI.APIKey}, + {"settings.ai.endpoint", settings.AI.Endpoint}, + {"settings.ai.model", settings.AI.Model}, + {"settings.ai.max_tokens", settings.AI.MaxTokens}, + {"settings.ai.timeout", settings.AI.Timeout}, + {"settings.ai.history_size", settings.AI.HistorySize}, + {"settings.ai.opsec_check", settings.AI.OpsecCheck}, + }); err != nil { + return err + } + } + + return nil } diff --git a/client/cmd/cli/root.go b/client/cmd/cli/root.go index 86e34aec..e0c8b9d1 100644 --- a/client/cmd/cli/root.go +++ b/client/cmd/cli/root.go @@ -20,14 +20,21 @@ func rootCmd(con *core.Console) (*cobra.Command, error) { } cmd.TraverseChildren = true - // 添加 --mcp flag + // Add --mcp flag cmd.PersistentFlags().String("mcp", "", "enable MCP server with address (e.g., 127.0.0.1:5005)") - // 添加 --rpc flag + // Add --rpc flag cmd.PersistentFlags().String("rpc", "", "enable local gRPC server with address (e.g., 127.0.0.1:15004)") - bind := command.MakeBind(cmd, con, "golang") command.BindCommonCommands(bind) - cmd.PersistentPreRunE, cmd.PersistentPostRunE = command.ConsoleRunnerCmd(con, cmd) + // Setup console runner + originalPre, originalPost := command.ConsoleRunnerCmd(con, cmd) + cmd.PersistentPreRunE = func(c *cobra.Command, args []string) error { + if originalPre != nil { + return originalPre(c, args) + } + return nil + } + cmd.PersistentPostRunE = originalPost cmd.AddCommand(command.ImplantCmd(con)) carapace.Gen(cmd) diff --git a/client/cmd/genhelp/gen_help.go b/client/cmd/genhelp/gen_help.go index 94ea620d..cd3f3215 100644 --- a/client/cmd/genhelp/gen_help.go +++ b/client/cmd/genhelp/gen_help.go @@ -1,7 +1,13 @@ package main import ( + "bytes" "fmt" + "io" + "os" + "sort" + "strings" + "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/services/clientrpc" "github.com/chainreactors/malice-network/client/assets" @@ -42,8 +48,6 @@ import ( "github.com/gookit/config/v2" "github.com/gookit/config/v2/yaml" "github.com/spf13/cobra" - "io" - "os" ) func init() { diff --git a/client/command/ai/analyze.go b/client/command/ai/analyze.go new file mode 100644 index 00000000..bcbb8570 --- /dev/null +++ b/client/command/ai/analyze.go @@ -0,0 +1,135 @@ +package ai + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AnalyzeCmd handles the analyze command - analyzes errors and provides suggestions +func AnalyzeCmd(cmd *cobra.Command, con *core.Console, args []string) error { + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + if settings.AI == nil || !settings.AI.Enable { + return fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if settings.AI.APIKey == "" { + return fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + // Get the error to analyze + var errorText string + if len(args) > 0 { + errorText = strings.Join(args, " ") + } + + if errorText == "" { + return fmt.Errorf("please provide an error message to analyze. Usage: analyze ") + } + + // Get context + historySize := settings.AI.HistorySize + if historySize <= 0 { + historySize = 20 + } + history := con.GetRecentHistory(historySize) + + // Build session context if available + sessionContext := buildSessionContext(con) + + // Build the analysis prompt + prompt := buildAnalysisPrompt(errorText, history, sessionContext) + + aiClient := core.NewAIClient(settings.AI) + + timeout := settings.AI.Timeout + if timeout <= 0 { + timeout = 30 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + fmt.Println("\nAnalyzing error...") + fmt.Println() + + // Use streaming for real-time output + response, err := aiClient.AskStream(ctx, prompt, nil, func(chunk string) { + fmt.Print(chunk) + }) + if err != nil { + return fmt.Errorf("AI analysis failed: %w", err) + } + + fmt.Println() + + // Parse command suggestions + commands := core.ParseCommandSuggestions(response) + if len(commands) > 0 { + fmt.Println("\nSuggested commands:") + for i, cmd := range commands { + fmt.Printf(" [%d] %s\n", i+1, cmd.Command) + } + } + + fmt.Println() + return nil +} + +func buildSessionContext(con *core.Console) string { + var sb strings.Builder + + session := con.GetInteractive() + if session != nil { + sb.WriteString(fmt.Sprintf("Current session: %s\n", session.SessionId)) + if session.Os != nil { + sb.WriteString(fmt.Sprintf("OS: %s %s\n", session.Os.Name, session.Os.Arch)) + } + if session.Process != nil { + sb.WriteString(fmt.Sprintf("Process: %s (PID: %d)\n", session.Process.Name, session.Process.Pid)) + sb.WriteString(fmt.Sprintf("User: %s\n", session.Process.Owner)) + } + } else { + sb.WriteString("No active session\n") + } + + return sb.String() +} + +func buildAnalysisPrompt(errorText string, history []string, sessionContext string) string { + var sb strings.Builder + + sb.WriteString("Analyze the following error and provide:\n") + sb.WriteString("1. Possible causes of the error\n") + sb.WriteString("2. Suggested solutions or workarounds\n") + sb.WriteString("3. Alternative commands that might work\n\n") + + sb.WriteString("Error message:\n") + sb.WriteString(errorText) + sb.WriteString("\n\n") + + if sessionContext != "" { + sb.WriteString("Session context:\n") + sb.WriteString(sessionContext) + sb.WriteString("\n") + } + + if len(history) > 0 { + sb.WriteString("Recent command history:\n") + for _, cmd := range history { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + } + + sb.WriteString("\nProvide a concise analysis. Wrap any command suggestions in backticks like `command`.") + + return sb.String() +} diff --git a/client/command/ai/ask.go b/client/command/ai/ask.go new file mode 100644 index 00000000..e2188948 --- /dev/null +++ b/client/command/ai/ask.go @@ -0,0 +1,80 @@ +package ai + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AskCmd handles the ask command +func AskCmd(cmd *cobra.Command, con *core.Console, args []string) error { + question := strings.Join(args, " ") + if question == "" { + return fmt.Errorf("please provide a question") + } + + // Load settings + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + if settings.AI == nil || !settings.AI.Enable { + return fmt.Errorf("AI is not enabled. Use 'ai-config --enable --api-key ' to enable it") + } + + if settings.AI.APIKey == "" { + return fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + // Get history settings + historySize, _ := cmd.Flags().GetInt("history") + noHistory, _ := cmd.Flags().GetBool("no-history") + + var history []string + if !noHistory { + history = con.GetRecentHistory(historySize) + } + + // Create AI client + aiClient := core.NewAIClient(settings.AI) + + // Create context with timeout + timeout := settings.AI.Timeout + if timeout <= 0 { + timeout = 30 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + fmt.Println("Thinking...") + + // Ask the AI + response, err := aiClient.Ask(ctx, question, history) + if err != nil { + return fmt.Errorf("AI error: %w", err) + } + + // Parse command suggestions + commands := core.ParseCommandSuggestions(response) + + // Display response + fmt.Printf("\n%s\n", response) + + // If there are command suggestions, list them + if len(commands) > 0 { + fmt.Println("\nSuggested commands:") + for i, cmd := range commands { + fmt.Printf(" [%d] %s\n", i+1, cmd.Command) + } + } + + fmt.Println() + + return nil +} diff --git a/client/command/ai/commands.go b/client/command/ai/commands.go new file mode 100644 index 00000000..c3b13db6 --- /dev/null +++ b/client/command/ai/commands.go @@ -0,0 +1,111 @@ +package ai + +import ( + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// Commands returns all AI-related commands +func Commands(con *core.Console) []*cobra.Command { + aiConfigCmd := &cobra.Command{ + Use: "ai-config", + Short: "Configure AI assistant settings", + Long: "Configure the AI assistant with your preferred provider (OpenAI or Claude), API key, model, and other settings.", + RunE: func(cmd *cobra.Command, args []string) error { + return AIConfigCmd(cmd, con) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Enable AI with OpenAI +ai-config --enable --provider openai --api-key "sk-xxx" --model gpt-4 + +// Enable AI with Claude +ai-config --enable --provider claude --api-key "sk-ant-xxx" --endpoint "https://api.anthropic.com/v1" --model claude-3-opus-20240229 + +// Show current configuration +ai-config --show + +// Disable AI +ai-config --disable +~~~`, + } + + aiConfigCmd.Flags().Bool("enable", false, "Enable AI assistant") + aiConfigCmd.Flags().Bool("disable", false, "Disable AI assistant") + aiConfigCmd.Flags().Bool("show", false, "Show current AI configuration") + aiConfigCmd.Flags().String("provider", "", "AI provider: openai or claude") + aiConfigCmd.Flags().String("api-key", "", "API key for the AI provider") + aiConfigCmd.Flags().String("endpoint", "", "API endpoint URL") + aiConfigCmd.Flags().String("model", "", "Model name (e.g., gpt-4, claude-3-opus-20240229)") + aiConfigCmd.Flags().Int("max-tokens", 0, "Maximum tokens in response") + aiConfigCmd.Flags().Int("timeout", 0, "Request timeout in seconds") + aiConfigCmd.Flags().Int("history-size", 0, "Number of history lines to include as context") + aiConfigCmd.Flags().Bool("opsec-check", false, "Enable AI OPSEC risk assessment for high-risk commands") + + askCmd := &cobra.Command{ + Use: "ask [question]", + Short: "Ask the AI assistant a question", + Long: "Ask the AI assistant a question with command history context. This is equivalent to using '? ' syntax.", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return AskCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Ask about commands +ask how do I list all sessions + +// Ask about current target +ask what commands can I run on this target + +// Ask with no history context +ask --no-history how to download a file +~~~`, + } + + askCmd.Flags().Int("history", 20, "Number of history lines to include as context") + askCmd.Flags().Bool("no-history", false, "Don't include command history in context") + + questionCmd := &cobra.Command{ + Use: "? [question]", + Short: "Ask the AI assistant (shortcut)", + Long: "Ask the AI assistant a question. This is equivalent to using '? ' syntax or the 'ask' command.", + Args: cobra.MinimumNArgs(1), + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + return AskCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + } + + analyzeCmd := &cobra.Command{ + Use: "analyze [error message]", + Short: "AI-powered error analysis and suggestions", + Long: "Analyze an error message using AI and get suggestions for resolution, including possible causes and alternative commands.", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return AnalyzeCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Analyze an error message +analyze Access denied when trying to read file + +// Analyze with more context +analyze "Error: permission denied for /etc/shadow" + +// Analyze a command failure +analyze "getsystem failed: UAC is enabled" +~~~`, + } + + return []*cobra.Command{aiConfigCmd, askCmd, questionCmd, analyzeCmd} +} diff --git a/client/command/ai/config.go b/client/command/ai/config.go new file mode 100644 index 00000000..ba39ea85 --- /dev/null +++ b/client/command/ai/config.go @@ -0,0 +1,160 @@ +package ai + +import ( + "fmt" + "strings" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AIConfigCmd handles the ai-config command +func AIConfigCmd(cmd *cobra.Command, con *core.Console) error { + showConfig, _ := cmd.Flags().GetBool("show") + enableAI, _ := cmd.Flags().GetBool("enable") + disableAI, _ := cmd.Flags().GetBool("disable") + + // Load current settings + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + // Initialize AI settings if nil + if settings.AI == nil { + settings.AI = &assets.AISettings{ + Enable: false, + Provider: "openai", + Endpoint: "https://api.openai.com/v1", + Model: "gpt-4", + MaxTokens: 1024, + Timeout: 30, + HistorySize: 20, + } + } + + // Show current config + if showConfig { + printAIConfig(settings.AI) + return nil + } + + // If no flags provided, show help + if !enableAI && !disableAI && !cmd.Flags().Changed("provider") && + !cmd.Flags().Changed("api-key") && !cmd.Flags().Changed("endpoint") && + !cmd.Flags().Changed("model") && !cmd.Flags().Changed("max-tokens") && + !cmd.Flags().Changed("timeout") && !cmd.Flags().Changed("history-size") { + printAIConfig(settings.AI) + fmt.Println("\nUse --help to see available options") + return nil + } + + // Update settings based on flags + if enableAI { + settings.AI.Enable = true + } + if disableAI { + settings.AI.Enable = false + } + + if provider, _ := cmd.Flags().GetString("provider"); provider != "" { + provider = strings.ToLower(provider) + if provider == "anthropic" { + provider = "claude" + } + if provider != "openai" && provider != "claude" { + return fmt.Errorf("invalid provider: %s. Must be 'openai' or 'claude'", provider) + } + settings.AI.Provider = provider + + // Set default endpoint based on provider + if !cmd.Flags().Changed("endpoint") { + if provider == "claude" { + settings.AI.Endpoint = "https://api.anthropic.com/v1" + } else { + settings.AI.Endpoint = "https://api.openai.com/v1" + } + } + } + + if apiKey, _ := cmd.Flags().GetString("api-key"); apiKey != "" { + settings.AI.APIKey = apiKey + } + + if endpoint, _ := cmd.Flags().GetString("endpoint"); endpoint != "" { + settings.AI.Endpoint = endpoint + } + + if model, _ := cmd.Flags().GetString("model"); model != "" { + settings.AI.Model = model + } + + if maxTokens, _ := cmd.Flags().GetInt("max-tokens"); maxTokens > 0 { + settings.AI.MaxTokens = maxTokens + } + + if timeout, _ := cmd.Flags().GetInt("timeout"); timeout > 0 { + settings.AI.Timeout = timeout + } + + if historySize, _ := cmd.Flags().GetInt("history-size"); historySize > 0 { + settings.AI.HistorySize = historySize + } + + if cmd.Flags().Changed("opsec-check") { + opsecCheck, _ := cmd.Flags().GetBool("opsec-check") + settings.AI.OpsecCheck = opsecCheck + } + + // Validate configuration if enabling + if settings.AI.Enable && settings.AI.APIKey == "" { + fmt.Println("Warning: AI is enabled but API key is not set. Use --api-key to set it.") + } + + // Save settings + if err := assets.SaveSettings(settings); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + fmt.Println("AI configuration updated successfully") + printAIConfig(settings.AI) + + return nil +} + +func printAIConfig(ai *assets.AISettings) { + fmt.Println("\nAI Configuration:") + fmt.Println("─────────────────────────────────────") + + enabledStr := "No" + if ai.Enable { + enabledStr = "Yes" + } + fmt.Printf(" Enabled: %s\n", enabledStr) + fmt.Printf(" Provider: %s\n", ai.Provider) + fmt.Printf(" Endpoint: %s\n", ai.Endpoint) + fmt.Printf(" Model: %s\n", ai.Model) + + // Mask API key + apiKeyDisplay := "(not set)" + if ai.APIKey != "" { + if len(ai.APIKey) > 8 { + apiKeyDisplay = ai.APIKey[:4] + "..." + ai.APIKey[len(ai.APIKey)-4:] + } else { + apiKeyDisplay = "****" + } + } + fmt.Printf(" API Key: %s\n", apiKeyDisplay) + + fmt.Printf(" Max Tokens: %d\n", ai.MaxTokens) + fmt.Printf(" Timeout: %ds\n", ai.Timeout) + fmt.Printf(" History Size: %d lines\n", ai.HistorySize) + + opsecCheckStr := "No" + if ai.OpsecCheck { + opsecCheckStr = "Yes" + } + fmt.Printf(" OPSEC Check: %s\n", opsecCheckStr) + fmt.Println() +} diff --git a/client/command/build/build-beacon.go b/client/command/build/build-beacon.go index 0602a55b..77a55d55 100644 --- a/client/command/build/build-beacon.go +++ b/client/command/build/build-beacon.go @@ -95,8 +95,7 @@ func BeaconCmd(cmd *cobra.Command, con *core.Console) error { if err != nil { return err } - executeBuild(con, buildConfig) - return nil + return ExecuteBuild(con, buildConfig) } // prepareBuildConfig 准备标准构建配置 diff --git a/client/command/build/build-module.go b/client/command/build/build-module.go index 458f7e53..4fa14542 100644 --- a/client/command/build/build-module.go +++ b/client/command/build/build-module.go @@ -49,8 +49,10 @@ func ModulesCmd(cmd *cobra.Command, con *core.Console) error { } else { mainProfile.Implant.Modules = strings.Split(modules, ",") } - buildConfig.MaleficConfig, _ = mainProfile.ToYAML() + buildConfig.MaleficConfig, err = mainProfile.ToYAML() + if err != nil { + return err + } - executeBuild(con, buildConfig) - return nil + return ExecuteBuild(con, buildConfig) } diff --git a/client/command/build/build-prelude.go b/client/command/build/build-prelude.go index aead3ae4..3e174979 100644 --- a/client/command/build/build-prelude.go +++ b/client/command/build/build-prelude.go @@ -42,6 +42,5 @@ func PreludeCmd(cmd *cobra.Command, con *core.Console) error { return err } - executeBuild(con, buildConfig) - return nil + return ExecuteBuild(con, buildConfig) } diff --git a/client/command/build/build-pulse.go b/client/command/build/build-pulse.go index 3d9c5d40..1813152b 100644 --- a/client/command/build/build-pulse.go +++ b/client/command/build/build-pulse.go @@ -42,9 +42,11 @@ func PulseCmd(cmd *cobra.Command, con *core.Console) error { return fmt.Errorf("failed to parse pulse's build flags: %w", err) } buildConfig.MaleficConfig, err = profile.ToYAML() + if err != nil { + return fmt.Errorf("failed to encode profile: %w", err) + } - executeBuild(con, buildConfig) - return nil + return ExecuteBuild(con, buildConfig) } func parsePulseBuildFlags(cmd *cobra.Command) (*implanttypes.ProfileConfig, error) { diff --git a/client/command/build/build.go b/client/command/build/build.go index 83d572a3..c21007e8 100644 --- a/client/command/build/build.go +++ b/client/command/build/build.go @@ -2,6 +2,7 @@ package build import ( "errors" + "fmt" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/malice-network/client/command/common" @@ -59,17 +60,15 @@ func parseSourceConfig(cmd *cobra.Command, con *core.Console, buildConfig *clien return buildConfig, nil } -// executeBuild 执行构建逻辑 -func executeBuild(con *core.Console, buildConfig *clientpb.BuildConfig) { - go func() { - artifact, err := con.Rpc.Build(con.Context(), buildConfig) - if err != nil { - con.Log.Errorf("Build %s failed: %v\n", buildConfig.BuildType, err) - return - } - con.Log.Infof("Build started: %s (type: %s, target: %s, source: %s)\n", - artifact.Name, artifact.Type, artifact.Target, artifact.Source) - }() +// ExecuteBuild executes the build logic. +func ExecuteBuild(con *core.Console, buildConfig *clientpb.BuildConfig) error { + artifact, err := con.Rpc.Build(con.Context(), buildConfig) + if err != nil { + return fmt.Errorf("build %s failed: %w", buildConfig.BuildType, err) + } + con.Log.Infof("Build started: %s (type: %s, target: %s, source: %s)\n", + artifact.Name, artifact.Type, artifact.Target, artifact.Source) + return nil } func BindCmd(cmd *cobra.Command, con *core.Console) error { @@ -78,13 +77,17 @@ func BindCmd(cmd *cobra.Command, con *core.Console) error { return err } - executeBuild(con, buildConfig) - return nil + return ExecuteBuild(con, buildConfig) } // parseLibFlag sets buildConfig.Lib based on the --lib flag and validates compatibility with buildType/target. func parseLibFlag(cmd *cobra.Command, buildConfig *clientpb.BuildConfig) error { libFlag, _ := cmd.Flags().GetBool("lib") + return ValidateLibFlag(buildConfig, libFlag, cmd.Flags().Changed("lib")) +} + +// ValidateLibFlag validates the lib flag and sets buildConfig.Lib. +func ValidateLibFlag(buildConfig *clientpb.BuildConfig, libFlag bool, libFlagChanged bool) error { target, ok := consts.GetBuildTarget(buildConfig.Target) if !ok { return errors.New("invalid target: " + buildConfig.Target) @@ -92,7 +95,7 @@ func parseLibFlag(cmd *cobra.Command, buildConfig *clientpb.BuildConfig) error { switch buildConfig.BuildType { case consts.CommandBuildModules, consts.CommandBuild3rdModules: - if cmd.Flags().Changed("lib") && !libFlag { + if libFlagChanged && !libFlag { return errors.New("modules build requires --lib") } if target.OS != consts.Windows { diff --git a/client/command/build/commands.go b/client/command/build/commands.go index 6e6b9a7a..4598070c 100644 --- a/client/command/build/commands.go +++ b/client/command/build/commands.go @@ -5,6 +5,7 @@ import ( "github.com/chainreactors/IoM-go/client" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/malice-network/client/core" + "github.com/chainreactors/malice-network/client/wizard" "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/proto/client/clientpb" @@ -149,6 +150,9 @@ build beacon --profile tcp_default --target x86_64-pc-windows-gnu --source saas // Build by GithubAction build beacon --profile tcp_default --target x86_64-pc-windows-gnu --source action + +// Use interactive wizard mode +build beacon --wizard ~~~`, } common.BindFlag(beaconCmd, @@ -326,6 +330,9 @@ build log artifact_name --limit 70 }) common.BindArgCompletions(logCmd, nil, common.ArtifactCompleter(con)) + // Enable wizard for all build commands (except logCmd which doesn't need it) + common.EnableWizardForCommands(beaconCmd, bindCmd, modulesCmd, pulseCmd, preludeCmd) + buildCmd.AddCommand(beaconCmd, bindCmd, modulesCmd, pulseCmd, preludeCmd, logCmd) artifactCmd := &cobra.Command{ @@ -458,6 +465,9 @@ artifact delete --name artifact_name } func Register(con *core.Console) { + // Register wizard providers for dynamic options and defaults + registerWizardProviders(con) + con.EventCallback[consts.CtrlArtifactDownload] = func(event *clientpb.Event) { err := WriteOriginArtifact(con, event.Job.Name) if err != nil { @@ -631,3 +641,74 @@ func Register(con *core.Console) { Example: `artifact_payload("tcp_default","raw","windows","x64")`, }) } + +// registerWizardProviders registers dynamic option providers and default values for wizard +func registerWizardProviders(con *core.Console) { + // ============ Option Providers (for select fields) ============ + + // Profile options + wizard.RegisterProvider("profile", func() []string { + profiles, err := con.Rpc.GetProfiles(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + var opts []string + for _, p := range profiles.Profiles { + opts = append(opts, p.Name) + } + return opts + }) + + // Target options + wizard.RegisterProvider("target", func() []string { + return []string{ + "x86_64-pc-windows-gnu", + "x86_64-pc-windows-msvc", + "i686-pc-windows-gnu", + "i686-pc-windows-msvc", + "x86_64-unknown-linux-gnu", + "i686-unknown-linux-gnu", + } + }) + + // Source options + wizard.RegisterProvider("source", func() []string { + return []string{"local", "docker", "action"} + }) + + // ============ Default Value Providers ============ + + // Default target - most common choice + wizard.RegisterDefaultProvider("target", func() string { + return "x86_64-pc-windows-gnu" + }) + + // Default source - docker is most reliable + wizard.RegisterDefaultProvider("source", func() string { + return "docker" + }) + + // Default profile - use first available profile + wizard.RegisterDefaultProvider("profile", func() string { + profiles, err := con.Rpc.GetProfiles(con.Context(), &clientpb.Empty{}) + if err != nil || len(profiles.Profiles) == 0 { + return "" + } + return profiles.Profiles[0].Name + }) + + // Default jitter + wizard.RegisterDefaultProvider("jitter", func() string { + return "0.2" + }) + + // Default secure - enable by default + wizard.RegisterDefaultProvider("secure", func() string { + return "true" + }) + + // Default auto-download + wizard.RegisterDefaultProvider("auto-download", func() string { + return "true" + }) +} diff --git a/client/command/cert/commands.go b/client/command/cert/commands.go index a8beee16..10cf46f6 100644 --- a/client/command/cert/commands.go +++ b/client/command/cert/commands.go @@ -34,6 +34,8 @@ cert import --cert cert_file_path --key key_file_path --ca-cert ca_cert_path } common.BindFlag(importCmd, common.ImportSet) + _ = importCmd.MarkFlagRequired("cert") + _ = importCmd.MarkFlagRequired("key") common.BindFlagCompletions(importCmd, func(comp carapace.ActionMap) { comp["cert"] = carapace.ActionFiles().Usage("path to the cert file") comp["key"] = carapace.ActionFiles().Usage("path to the key file") @@ -141,6 +143,9 @@ cert download cert-name -o cert_path common.BindFlag(downloadCmd, func(f *pflag.FlagSet) { f.StringP("output", "o", "", "cert save path") }) + // Enable wizard for cert commands that need configuration + common.EnableWizardForCommands(importCmd, selfSignCmd, updateCmd) + certCmd.AddCommand(importCmd, selfSignCmd, delCmd, updateCmd, downloadCmd) //certCmd.AddCommand(importCmd, selfSignCmd, acmeCmd, delCmd, updateCmd, downloadCmd) return []*cobra.Command{ diff --git a/client/command/client.go b/client/command/client.go index 10aad0ab..6104fd86 100644 --- a/client/command/client.go +++ b/client/command/client.go @@ -1,7 +1,9 @@ package command import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/malice-network/client/command/ai" "github.com/chainreactors/malice-network/client/command/audit" "github.com/chainreactors/malice-network/client/core" "github.com/reeflective/console" @@ -28,7 +30,8 @@ import ( func BindCommonCommands(bind BindFunc) { bind(consts.GenericGroup, - generic.Commands) + generic.Commands, + ai.Commands) bind(consts.ManageGroup, sessions.Commands, @@ -99,6 +102,9 @@ func BindClientsCommands(con *core.Console) console.Commands { client.SetHelpFunc(help.HelpFunc) client.SetHelpCommandGroupID(consts.GenericGroup) + // Register carapace completion for root command (make PersistentFlags visible in subcommands) + carapace.Gen(client) + RegisterClientFunc(con) RegisterImplantFunc(con) return client diff --git a/client/command/common/flagset.go b/client/command/common/flagset.go index ba4e3599..1c7f28e9 100644 --- a/client/command/common/flagset.go +++ b/client/command/common/flagset.go @@ -1,6 +1,9 @@ package common import ( + "errors" + "strings" + "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/IoM-go/proto/implant/implantpb" "github.com/chainreactors/malice-network/helper/cryptography" @@ -333,29 +336,35 @@ func ParseImportCertFlags(cmd *cobra.Command) (*clientpb.TLS, error) { keyPath, _ := cmd.Flags().GetString("key") caPath, _ := cmd.Flags().GetString("ca-cert") - var err error - var cert, key, ca string - if certPath != "" && keyPath != "" && caPath != "" { - cert, err = cryptography.ProcessPEM(certPath) - if err != nil { - return nil, err - } - key, err = cryptography.ProcessPEM(keyPath) - if err != nil { - return nil, err - } - ca, err = cryptography.ProcessPEM(caPath) - if err != nil { - return nil, err - } + certPath = strings.TrimSpace(certPath) + keyPath = strings.TrimSpace(keyPath) + caPath = strings.TrimSpace(caPath) + + if certPath == "" || keyPath == "" { + return nil, errors.New("cert and key are required") } - return &clientpb.TLS{ + + cert, err := cryptography.ProcessPEM(certPath) + if err != nil { + return nil, err + } + key, err := cryptography.ProcessPEM(keyPath) + if err != nil { + return nil, err + } + + tls := &clientpb.TLS{ Cert: &clientpb.Cert{ Cert: cert, Key: key, }, - Ca: &clientpb.Cert{ - Cert: ca, - }, - }, nil + } + if caPath != "" { + ca, err := cryptography.ProcessPEM(caPath) + if err != nil { + return nil, err + } + tls.Ca = &clientpb.Cert{Cert: ca} + } + return tls, nil } diff --git a/client/command/common/wizard.go b/client/command/common/wizard.go new file mode 100644 index 00000000..8f0f8b5a --- /dev/null +++ b/client/command/common/wizard.go @@ -0,0 +1,67 @@ +package common + +import ( + "fmt" + + "github.com/chainreactors/malice-network/client/wizard" + "github.com/spf13/cobra" +) + +// AddWizardFlag adds the --wizard flag to a command +func AddWizardFlag(cmd *cobra.Command) { + cmd.Flags().Bool("wizard", false, "Start interactive wizard mode") +} + +// WrapPreRunEWithWizard wraps a command's PreRunE to support wizard mode. +// Usage: cmd.PreRunE = common.WrapPreRunEWithWizard(originalPreRunE, originalPreRun) +func WrapPreRunEWithWizard( + originalPreRunE func(cmd *cobra.Command, args []string) error, + originalPreRun func(cmd *cobra.Command, args []string), +) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, args []string) error { + if wizardMode, _ := cmd.Flags().GetBool("wizard"); wizardMode { + if _, err := wizard.RunWizard(cmd); err != nil { + return fmt.Errorf("wizard failed: %w", err) + } + } + if originalPreRunE != nil { + return originalPreRunE(cmd, args) + } + if originalPreRun != nil { + originalPreRun(cmd, args) + } + return nil + } +} + +// WrapRunEWithWizard wraps a command's RunE to support wizard mode +// Usage: cmd.RunE = common.WrapRunEWithWizard(cmd, originalRunE) +func WrapRunEWithWizard(originalRunE func(cmd *cobra.Command, args []string) error) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, args []string) error { + if wizardMode, _ := cmd.Flags().GetBool("wizard"); wizardMode { + if _, err := wizard.RunWizard(cmd); err != nil { + return fmt.Errorf("wizard failed: %w", err) + } + } + return originalRunE(cmd, args) + } +} + +// EnableWizard adds --wizard flag and wraps PreRunE for a command +// This is a convenience function that combines AddWizardFlag and WrapPreRunEWithWizard +func EnableWizard(cmd *cobra.Command) { + if cmd.RunE == nil && cmd.Run == nil { + return + } + AddWizardFlag(cmd) + originalPreRunE := cmd.PreRunE + originalPreRun := cmd.PreRun + cmd.PreRunE = WrapPreRunEWithWizard(originalPreRunE, originalPreRun) +} + +// EnableWizardForCommands enables wizard for multiple commands +func EnableWizardForCommands(cmds ...*cobra.Command) { + for _, cmd := range cmds { + EnableWizard(cmd) + } +} diff --git a/client/command/config/commands.go b/client/command/config/commands.go index 3627a9d9..9ae63507 100644 --- a/client/command/config/commands.go +++ b/client/command/config/commands.go @@ -68,6 +68,9 @@ func Commands(con *core.Console) []*cobra.Command { notifyCmd.AddCommand(notifyUpdateCmd) + // Enable wizard for config commands that need configuration + common.EnableWizardForCommands(githubUpdateCmd, notifyUpdateCmd) + configCmd.AddCommand(configRefreshCmd, githubCmd, notifyCmd) return []*cobra.Command{configCmd} } diff --git a/client/command/explorer/commands.go b/client/command/explorer/commands.go index 8e0be913..c3da49f1 100644 --- a/client/command/explorer/commands.go +++ b/client/command/explorer/commands.go @@ -9,8 +9,9 @@ import ( func Commands(con *core.Console) []*cobra.Command { regCommand := &cobra.Command{ - Use: consts.CommandRegExplorer, - Short: "registry explorer", + Use: consts.CommandRegExplorer + " [hive\\path]", + Short: "Interactive registry explorer", + Long: "Explore registry keys and values interactively from a starting hive/path (e.g., HKEY_LOCAL_MACHINE\\SOFTWARE).", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { return regExplorerCmd(cmd, con) @@ -19,6 +20,10 @@ func Commands(con *core.Console) []*cobra.Command { "depend": consts.ModuleRegListKey, "thirdParty": "true", }, + Example: `~~~ +reg_explorer HKLM\\SOFTWARE +reg_explorer HKEY_CURRENT_USER\\Software +~~~`, } fileCmd := &cobra.Command{ diff --git a/client/command/generic/login.go b/client/command/generic/login.go index 08f18c59..fd954949 100644 --- a/client/command/generic/login.go +++ b/client/command/generic/login.go @@ -3,6 +3,7 @@ package generic import ( "errors" "fmt" + "strings" "github.com/chainreactors/malice-network/client/assets" "github.com/chainreactors/malice-network/client/core" @@ -28,11 +29,19 @@ func LoginCmd(cmd *cobra.Command, con *core.Console) error { con.RPCAddr = rpcAddr } - if filename := cmd.Flags().Arg(0); filename != "" { - return Login(con, filename) - } else if filename, _ := cmd.Flags().GetString("auth"); filename != "" { + // Prefer explicit --auth flag to avoid misinterpreting subcommand arguments + // (e.g. `build beacon`) as an auth file. + if filename, _ := cmd.Flags().GetString("auth"); filename != "" { return Login(con, filename) } + + // Only check Arg(0) as auth file for root command or login command + // Avoid treating subcommand arguments (e.g., 'beacon' in 'build beacon') as auth file + if cmd.Parent() == nil || cmd.Use == "client" || cmd.Use == "login" { + if filename := cmd.Flags().Arg(0); strings.HasSuffix(filename, ".auth") { + return Login(con, filename) + } + } files, err := assets.GetConfigs() if err != nil { return fmt.Errorf("error retrieving YAML files: %w", err) diff --git a/client/command/mutant/commands.go b/client/command/mutant/commands.go index 8d035a2a..cf3e6005 100644 --- a/client/command/mutant/commands.go +++ b/client/command/mutant/commands.go @@ -210,6 +210,9 @@ func Commands(con *core.Console) []*cobra.Command { // Add subcommands to mutant parent command (excluding donut) mutantCmd.AddCommand(srdiCmd, stripCmd, sigforgeCmd) + // Enable wizard for mutant commands that need configuration + common.EnableWizardForCommands(donutCmd, srdiCmd, stripCmd, sigforgeCmd) + // Return mutant as parent command and donut as standalone return []*cobra.Command{mutantCmd, donutCmd} } diff --git a/client/command/pipeline/commands.go b/client/command/pipeline/commands.go index 96163528..d0dde8c0 100644 --- a/client/command/pipeline/commands.go +++ b/client/command/pipeline/commands.go @@ -207,5 +207,8 @@ rem delete rem_test remCmd.AddCommand(listremCmd, newRemCmd, startRemCmd, stopRemCmd, deleteRemCmd) + // Enable wizard for pipeline commands + common.EnableWizardForCommands(tcpCmd, httpCmd, bindCmd, newRemCmd) + return []*cobra.Command{tcpCmd, httpCmd, bindCmd, remCmd} } diff --git a/client/command/privilege/commands.go b/client/command/privilege/commands.go index 6dce3149..76e3e518 100644 --- a/client/command/privilege/commands.go +++ b/client/command/privilege/commands.go @@ -1,6 +1,7 @@ package privilege import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/malice-network/client/command/common" "github.com/chainreactors/malice-network/client/core" @@ -10,7 +11,7 @@ import ( func Commands(con *core.Console) []*cobra.Command { runasCmd := &cobra.Command{ - Use: "runas --username [username] --domain [domain] --password [password] --program [program] --args [args] --use-profile --use-env --netonly", + Use: "runas --username [username] --domain [domain] --password [password] --path [path] --args [args] --use-profile --use-env --netonly", Short: "Run a program as another user", RunE: func(cmd *cobra.Command, args []string) error { return RunasCmd(cmd, con) @@ -21,7 +22,7 @@ func Commands(con *core.Console) []*cobra.Command { }, Example: `Run a program as a different user: ~~~ - sys runas --username admin --domain EXAMPLE --password admin123 --program /path/to/program --args "arg1 arg2" --use-profile --use-env + sys runas --username admin --domain EXAMPLE --password admin123 --path /path/to/program --args "arg1 arg2" --use-profile --use-env ~~~`, } @@ -35,6 +36,9 @@ func Commands(con *core.Console) []*cobra.Command { f.Bool("use-env", false, "Use user environment") f.Bool("netonly", false, "Use network credentials only") }) + common.BindFlagCompletions(runasCmd, func(comp carapace.ActionMap) { + comp["path"] = carapace.ActionFiles().Usage("path to the program to execute") + }) privsCmd := &cobra.Command{ Use: "privs", diff --git a/client/command/reg/commands.go b/client/command/reg/commands.go index 298bd4c3..153b903a 100644 --- a/client/command/reg/commands.go +++ b/client/command/reg/commands.go @@ -1,6 +1,7 @@ package reg import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/malice-network/client/core" "strings" @@ -74,6 +75,16 @@ func Commands(con *core.Console) []*cobra.Command { f.StringP("type", "t", "REG_SZ", "Value type (REG_SZ, REG_BINARY, REG_DWORD, REG_QWORD)") f.StringP("data", "d", "", "Data to set") }) + common.BindFlagCompletions(regAddCmd, func(comp carapace.ActionMap) { + comp["type"] = carapace.ActionValuesDescribed( + "REG_SZ", "String", + "REG_EXPAND_SZ", "Expandable string", + "REG_MULTI_SZ", "Multi-string", + "REG_BINARY", "Binary data", + "REG_DWORD", "32-bit number", + "REG_QWORD", "64-bit number", + ).Tag("registry value type") + }) regDeleteCmd := &cobra.Command{ Use: consts.SubCommandName(consts.ModuleRegDelete) + " --hive [hive] --path [path] --key [key]", diff --git a/client/command/service/commands.go b/client/command/service/commands.go index 2ecb922e..33bb0d2d 100644 --- a/client/command/service/commands.go +++ b/client/command/service/commands.go @@ -143,6 +143,9 @@ Control the start type and error control by providing appropriate values.`, serviceCmd.AddCommand(serviceListCmd, serviceCreateCmd, serviceStartCmd, serviceStopCmd, serviceQueryCmd, serviceDeleteCmd) + // Enable wizard for service commands that need configuration + common.EnableWizardForCommands(serviceCreateCmd) + return []*cobra.Command{serviceCmd} } diff --git a/client/command/sys/ps.go b/client/command/sys/ps.go index ef2e47d8..0c4e110a 100644 --- a/client/command/sys/ps.go +++ b/client/command/sys/ps.go @@ -48,7 +48,7 @@ func RegisterPsFunc(con *core.Console) { psSet := ctx.Spite.GetPsResponse() var ps []string for _, p := range psSet.GetProcesses() { - ps = append(ps, fmt.Sprintf("%s:%d:%d:%s:%s:%s:%s:%s", + ps = append(ps, fmt.Sprintf("%s:%d:%d:%s:%s:%s:%s", p.Name, p.Pid, p.Ppid, diff --git a/client/command/taskschd/commands.go b/client/command/taskschd/commands.go index 21d58473..d84264a4 100644 --- a/client/command/taskschd/commands.go +++ b/client/command/taskschd/commands.go @@ -171,6 +171,9 @@ func Commands(con *core.Console) []*cobra.Command { }) taskschdCmd.AddCommand(taskSchdListCmd, taskSchdCreateCmd, taskSchdStartCmd, taskSchdStopCmd, taskSchdDeleteCmd, taskSchdQueryCmd, taskSchdRunCmd) + // Enable wizard for taskschd commands that need configuration + common.EnableWizardForCommands(taskSchdCreateCmd) + return []*cobra.Command{taskschdCmd} } diff --git a/client/command/website/commands.go b/client/command/website/commands.go index e41cde47..62b225e2 100644 --- a/client/command/website/commands.go +++ b/client/command/website/commands.go @@ -203,6 +203,9 @@ website list-content web_test common.BindArgCompletions(websiteListContentCmd, nil, common.WebsiteCompleter(con)) + // Enable wizard for website commands that need configuration + common.EnableWizardForCommands(websiteCmd, websiteAddContentCmd, websiteUpdateContentCmd) + websiteCmd.AddCommand(websiteListCmd, websiteStartCmd, websiteStopCmd, websiteAddContentCmd, websiteUpdateContentCmd, websiteRemoveContentCmd, websiteListContentCmd) diff --git a/client/command/wizard_flag.go b/client/command/wizard_flag.go new file mode 100644 index 00000000..054bcbeb --- /dev/null +++ b/client/command/wizard_flag.go @@ -0,0 +1,38 @@ +package command + +import ( + "fmt" + + "github.com/chainreactors/malice-network/client/wizard" + "github.com/spf13/cobra" +) + +// WizardFlagName is the name of the global wizard flag +const WizardFlagName = "wizard" + +// AddWizardFlag adds the --wizard flag to a command +func AddWizardFlag(cmd *cobra.Command) { + cmd.Flags().Bool(WizardFlagName, false, "Start interactive wizard mode") +} + +// ShouldRunWizard checks if the command should run in wizard mode +func ShouldRunWizard(cmd *cobra.Command) bool { + wizardMode, _ := cmd.Flags().GetBool(WizardFlagName) + return wizardMode +} + +// RunWizardIfEnabled checks if wizard mode is enabled and runs it +// Returns true if wizard was run, false otherwise +func RunWizardIfEnabled(cmd *cobra.Command) (bool, error) { + if !ShouldRunWizard(cmd) { + return false, nil + } + + // Run wizard - this handles everything including applying results to flags + _, err := wizard.RunWizard(cmd) + if err != nil { + return true, fmt.Errorf("wizard failed: %w", err) + } + + return true, nil +} diff --git a/client/core/ai.go b/client/core/ai.go new file mode 100644 index 00000000..373ad8ac --- /dev/null +++ b/client/core/ai.go @@ -0,0 +1,564 @@ +package core + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" +) + +// AIClient handles communication with AI APIs (OpenAI and Claude) +type AIClient struct { + settings *assets.AISettings + client *http.Client +} + +// NewAIClient creates a new AI client +func NewAIClient(settings *assets.AISettings) *AIClient { + timeout := 30 + if settings != nil && settings.Timeout > 0 { + timeout = settings.Timeout + } + return &AIClient{ + settings: settings, + client: &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + }, + } +} + +// Message represents a chat message +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// OpenAI API structures +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type OpenAIChatResponse struct { + ID string `json:"id"` + Choices []struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error,omitempty"` +} + +// Claude API structures +type ClaudeChatRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages"` +} + +type ClaudeMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ClaudeChatResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + StopReason string `json:"stop_reason"` + Error *struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// CommandSuggestion represents a command extracted from AI response +type CommandSuggestion struct { + Command string + Description string +} + +// Ask sends a question to the AI with context +func (c *AIClient) Ask(ctx context.Context, question string, history []string) (string, error) { + if c.settings == nil || !c.settings.Enable { + return "", fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if c.settings.APIKey == "" { + return "", fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + systemPrompt := c.buildSystemPrompt(history) + + switch strings.ToLower(c.settings.Provider) { + case "claude", "anthropic": + return c.askClaude(ctx, systemPrompt, question) + default: // openai and compatible + return c.askOpenAI(ctx, systemPrompt, question) + } +} + +func (c *AIClient) buildSystemPrompt(history []string) string { + var sb strings.Builder + sb.WriteString("You are an AI assistant for IoM (Malice Network), a C2 framework. ") + sb.WriteString("Help users with commands, security operations, and answer questions. ") + sb.WriteString("Be concise and provide actionable suggestions when possible.\n\n") + + sb.WriteString("When suggesting commands, wrap them in backticks like `command`. ") + sb.WriteString("This helps users identify executable commands.\n\n") + + sb.WriteString("IMPORTANT: Use EXACT command names as listed below. Do NOT use plural forms or variations. ") + sb.WriteString("For example, use `session` NOT `sessions`, use `listener` NOT `listeners`.\n\n") + + if len(history) > 0 { + sb.WriteString("Recent command history:\n") + for _, cmd := range history { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + sb.WriteString("\n") + } + + sb.WriteString("Available commands (use these EXACT names):\n") + sb.WriteString("- session: List and manage sessions (NOT 'sessions')\n") + sb.WriteString("- listener: List listeners in server (NOT 'listeners')\n") + sb.WriteString("- use : Switch to a session\n") + sb.WriteString("- ps: List processes\n") + sb.WriteString("- ls, cd, pwd: File system navigation\n") + sb.WriteString("- download, upload: File transfer\n") + sb.WriteString("- execute, shell, run: Run commands on target\n") + sb.WriteString("- job: List jobs\n") + sb.WriteString("- pipeline: Manage pipelines\n") + sb.WriteString("- build: Build implants\n") + + return sb.String() +} + +// doRequest sends an HTTP POST request and returns the response body. +func (c *AIClient) doRequest(ctx context.Context, endpoint string, headers map[string]string, body []byte) ([]byte, int, error) { + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + for k, v := range headers { + httpReq.Header.Set(k, v) + } + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, 0, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("failed to read response: %w", err) + } + + return respBody, resp.StatusCode, nil +} + +// buildEndpoint constructs the API endpoint URL with the given suffix. +func (c *AIClient) buildEndpoint(suffix string) (string, error) { + base := strings.TrimSuffix(strings.TrimSpace(c.settings.Endpoint), "/") + if base == "" { + return "", fmt.Errorf("AI endpoint is not configured. Use 'ai-config --endpoint ' to set it") + } + if !strings.HasSuffix(base, suffix) { + return base + suffix, nil + } + return base, nil +} + +func (c *AIClient) askOpenAI(ctx context.Context, systemPrompt, question string) (string, error) { + return c.askOpenAIWith(ctx, systemPrompt, question, c.settings.MaxTokens, 0.7) +} + +func (c *AIClient) askOpenAIWith(ctx context.Context, systemPrompt, question string, maxTokens int, temperature float64) (string, error) { + if maxTokens <= 0 { + maxTokens = c.settings.MaxTokens + } + if temperature < 0 { + temperature = 0.7 + } + + req := OpenAIChatRequest{ + Model: c.settings.Model, + Messages: []Message{{Role: "system", Content: systemPrompt}, {Role: "user", Content: question}}, + MaxTokens: maxTokens, + Temperature: temperature, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/chat/completions") + if err != nil { + return "", err + } + + respBody, statusCode, err := c.doRequest(ctx, endpoint, map[string]string{ + "Authorization": "Bearer " + c.settings.APIKey, + }, body) + if err != nil { + return "", err + } + + var chatResp OpenAIChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + if statusCode < 200 || statusCode >= 300 { + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if statusCode < 200 || statusCode >= 300 { + if chatResp.Error != nil { + return "", fmt.Errorf("API error (%d): %s", statusCode, chatResp.Error.Message) + } + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if len(chatResp.Choices) == 0 { + return "", fmt.Errorf("no response from AI") + } + + return chatResp.Choices[0].Message.Content, nil +} + +func (c *AIClient) askClaude(ctx context.Context, systemPrompt, question string) (string, error) { + return c.askClaudeWith(ctx, systemPrompt, question, c.settings.MaxTokens) +} + +func (c *AIClient) askClaudeWith(ctx context.Context, systemPrompt, question string, maxTokens int) (string, error) { + if maxTokens <= 0 { + maxTokens = c.settings.MaxTokens + } + if maxTokens <= 0 { + maxTokens = 256 + } + + req := ClaudeChatRequest{ + Model: c.settings.Model, + MaxTokens: maxTokens, + System: systemPrompt, + Messages: []ClaudeMessage{{Role: "user", Content: question}}, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/messages") + if err != nil { + return "", err + } + + respBody, statusCode, err := c.doRequest(ctx, endpoint, map[string]string{ + "x-api-key": c.settings.APIKey, + "anthropic-version": "2023-06-01", + }, body) + if err != nil { + return "", err + } + + var chatResp ClaudeChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + if statusCode < 200 || statusCode >= 300 { + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if statusCode < 200 || statusCode >= 300 { + if chatResp.Error != nil { + return "", fmt.Errorf("API error (%d): %s", statusCode, chatResp.Error.Message) + } + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if len(chatResp.Content) == 0 { + return "", fmt.Errorf("no response from AI") + } + + var result strings.Builder + for _, content := range chatResp.Content { + if content.Type == "text" { + result.WriteString(content.Text) + } + } + + return result.String(), nil +} + +// ParseCommandSuggestions extracts command suggestions from AI response +// Commands are expected to be wrapped in backticks like `command` +func ParseCommandSuggestions(response string) []CommandSuggestion { + var suggestions []CommandSuggestion + + // Match single backtick commands: `command` + singlePattern := regexp.MustCompile("`([^`\n]+)`") + matches := singlePattern.FindAllStringSubmatch(response, -1) + + seen := make(map[string]bool) + for _, match := range matches { + if len(match) > 1 { + cmd := strings.TrimSpace(match[1]) + // Skip if it looks like code/variable rather than command + if strings.Contains(cmd, "=") || strings.HasPrefix(cmd, "$") { + continue + } + // Skip shell escape syntax (! prefix) + if strings.HasPrefix(cmd, "!") { + continue + } + if !seen[cmd] { + seen[cmd] = true + suggestions = append(suggestions, CommandSuggestion{ + Command: cmd, + Description: "", + }) + } + } + } + + return suggestions +} + +// FormatResponseWithCommands formats the AI response with numbered command suggestions +func FormatResponseWithCommands(response string, commands []CommandSuggestion) string { + if len(commands) == 0 { + return response + } + + var sb strings.Builder + sb.WriteString(response) + sb.WriteString("\n\n") + sb.WriteString("Suggested commands:\n") + + for i, cmd := range commands { + sb.WriteString(fmt.Sprintf(" [%d] %s\n", i+1, cmd.Command)) + } + + return sb.String() +} + +// OpenAI streaming response structures +type OpenAIStreamChunk struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +// Claude streaming response structures +type ClaudeStreamEvent struct { + Type string `json:"type"` + Index int `json:"index,omitempty"` + Delta *struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"delta,omitempty"` +} + +// AskStream sends a question to the AI and streams the response +func (c *AIClient) AskStream(ctx context.Context, question string, history []string, onChunk func(chunk string)) (string, error) { + if c.settings == nil || !c.settings.Enable { + return "", fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if c.settings.APIKey == "" { + return "", fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + systemPrompt := c.buildSystemPrompt(history) + + switch strings.ToLower(c.settings.Provider) { + case "claude", "anthropic": + return c.askClaudeStream(ctx, systemPrompt, question, onChunk) + default: // openai and compatible + return c.askOpenAIStream(ctx, systemPrompt, question, onChunk) + } +} + +func (c *AIClient) askOpenAIStream(ctx context.Context, systemPrompt, question string, onChunk func(chunk string)) (string, error) { + req := OpenAIChatRequest{ + Model: c.settings.Model, + Messages: []Message{{Role: "system", Content: systemPrompt}, {Role: "user", Content: question}}, + MaxTokens: c.settings.MaxTokens, + Temperature: 0.7, + Stream: true, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/chat/completions") + if err != nil { + return "", err + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Authorization", "Bearer "+c.settings.APIKey) + + resp, err := c.client.Do(httpReq) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + + var fullResponse strings.Builder + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk OpenAIStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { + content := chunk.Choices[0].Delta.Content + fullResponse.WriteString(content) + if onChunk != nil { + onChunk(content) + } + } + } + + if err := scanner.Err(); err != nil { + return fullResponse.String(), fmt.Errorf("stream read error: %w", err) + } + + return fullResponse.String(), nil +} + +func (c *AIClient) askClaudeStream(ctx context.Context, systemPrompt, question string, onChunk func(chunk string)) (string, error) { + reqBody := map[string]interface{}{ + "model": c.settings.Model, + "max_tokens": c.settings.MaxTokens, + "system": systemPrompt, + "messages": []ClaudeMessage{{Role: "user", Content: question}}, + "stream": true, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/messages") + if err != nil { + return "", err + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("x-api-key", c.settings.APIKey) + httpReq.Header.Set("anthropic-version", "2023-06-01") + + resp, err := c.client.Do(httpReq) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + + var fullResponse strings.Builder + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + var event ClaudeStreamEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + if event.Type == "content_block_delta" && event.Delta != nil && event.Delta.Text != "" { + fullResponse.WriteString(event.Delta.Text) + if onChunk != nil { + onChunk(event.Delta.Text) + } + } + + if event.Type == "message_stop" { + break + } + } + + if err := scanner.Err(); err != nil { + return fullResponse.String(), fmt.Errorf("stream read error: %w", err) + } + + return fullResponse.String(), nil +} diff --git a/client/core/console.go b/client/core/console.go index 491a3aa0..3a423607 100644 --- a/client/core/console.go +++ b/client/core/console.go @@ -89,6 +89,19 @@ func (c *Console) NewConsole() { implant.Prompt().Primary = c.GetPrompt implant.AddInterrupt(io.EOF, repl.ExitImplantMenu) // Ctrl-D implant.AddHistorySourceFile("history", filepath.Join(assets.GetRootAppDir(), "implant_history")) + + // Register line hook to handle '?' prefix without space (e.g., '?hello' -> '?' 'hello') + iom.PreCmdRunLineHooks = append(iom.PreCmdRunLineHooks, func(args []string) ([]string, error) { + if len(args) > 0 && len(args[0]) > 1 && strings.HasPrefix(args[0], "?") { + // Split '?xxx' into '?' and 'xxx' + question := args[0][1:] + newArgs := make([]string, 0, len(args)+1) + newArgs = append(newArgs, "?", question) + newArgs = append(newArgs, args[1:]...) + return newArgs, nil + } + return args, nil + }) } func (c *Console) Start(bindCmds ...BindCmds) error { @@ -106,7 +119,7 @@ func (c *Console) Start(bindCmds ...BindCmds) error { c.App.Menu(consts.ClientMenu).Command = bindCmds[0](c)() c.App.Menu(consts.ImplantMenu).Command = bindCmds[1](c)() - // 所有命令注册完成后,安全地启动MCP服务器和Local RPC服务器 + // After all commands are registered, safely start MCP server and Local RPC server if c.Server != nil { c.InitMCPServer() c.InitLocalRPCServer() @@ -273,3 +286,53 @@ func (c *Console) AddCommandFuncHelper(cmdName string, funcName string, example }) } } + +func (c *Console) GetRecentHistory(limit int) []string { + if limit <= 0 || c == nil || c.App == nil { + return nil + } + + shell := c.App.Shell() + if shell == nil || shell.History == nil || shell.History.Current() == nil { + return nil + } + + hist := shell.History.Current() + count := hist.Len() + start := count - limit + if start < 0 { + start = 0 + } + + capacity := limit + if count-start < capacity { + capacity = count - start + } + history := make([]string, 0, capacity) + for i := start; i < count; i++ { + if line, err := hist.GetLine(i); err == nil && line != "" { + history = append(history, line) + } + } + + if len(history) > limit { + history = history[len(history)-limit:] + } + + return history +} + +func getValidAISettings() (*assets.AISettings, error) { + settings, err := assets.GetSetting() + if err != nil { + return nil, fmt.Errorf("failed to load settings: %w", err) + } + if settings == nil || settings.AI == nil || !settings.AI.Enable { + return nil, fmt.Errorf("AI not enabled. Use 'ai-config --enable --api-key ' to enable it") + } + if settings.AI.APIKey == "" { + return nil, fmt.Errorf("AI API key not configured. Use 'ai-config --api-key ' to set it") + } + + return settings.AI, nil +} diff --git a/client/core/localrpc_test.go b/client/core/localrpc_test.go index 72d1508a..a6ba5ec0 100644 --- a/client/core/localrpc_test.go +++ b/client/core/localrpc_test.go @@ -3,7 +3,9 @@ package core import ( "context" "encoding/json" + "net" "testing" + "time" "github.com/chainreactors/IoM-go/proto/services/localrpc" "github.com/chainreactors/malice-network/client/plugin" @@ -18,9 +20,26 @@ const ( // setupRPCClient creates a gRPC client connection to the test RPC server func setupRPCClient(t *testing.T) (localrpc.CommandServiceClient, *grpc.ClientConn) { - conn, err := grpc.Dial(testRPCAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Helper() + + // These are integration tests; skip when no local RPC server is running. + if c, err := net.DialTimeout("tcp", testRPCAddr, 250*time.Millisecond); err != nil { + t.Skipf("Skipping: local RPC server not reachable at %s: %v", testRPCAddr, err) + } else { + _ = c.Close() + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + ctx, + testRPCAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) if err != nil { - t.Fatalf("Failed to connect to RPC server at %s: %v", testRPCAddr, err) + t.Skipf("Skipping: failed to connect to RPC server at %s: %v", testRPCAddr, err) } client := localrpc.NewCommandServiceClient(conn) diff --git a/client/plugin/lua.go b/client/plugin/lua.go index f6170925..d918d5ef 100644 --- a/client/plugin/lua.go +++ b/client/plugin/lua.go @@ -20,6 +20,7 @@ import ( "github.com/chainreactors/IoM-go/types" "github.com/chainreactors/logs" "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/command/common" "github.com/chainreactors/malice-network/helper/intermediate" "github.com/chainreactors/mals" ) @@ -414,6 +415,7 @@ func (plug *LuaPlugin) RegisterLuaFunction() { } logs.Log.Debugf("Registered Command: %s\n", cmd.Name) + common.EnableWizard(malCmd) plug.CMDs.SetCommand(name, malCmd) return malCmd, nil }, &mals.Helper{Group: intermediate.ClientGroup}) diff --git a/client/plugin/vm.go b/client/plugin/vm.go index e5ff26cc..2ae72b42 100644 --- a/client/plugin/vm.go +++ b/client/plugin/vm.go @@ -2,14 +2,15 @@ package plugin import ( "fmt" + "strings" + "sync" + "time" + "github.com/chainreactors/logs" "github.com/chainreactors/malice-network/helper/intermediate" "github.com/chainreactors/mals" lua "github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua/parse" - "strings" - "sync" - "time" ) func NewLuaVM() *lua.LState { @@ -27,6 +28,7 @@ func NewLuaVM() *lua.LState { for name, fun := range intermediate.InternalFunctions.Package(intermediate.BuiltinPackage) { vm.SetGlobal(name, vm.NewFunction(mals.WrapFuncForLua(fun))) } + return vm } diff --git a/client/wizard/cobra.go b/client/wizard/cobra.go new file mode 100644 index 00000000..c4c2f785 --- /dev/null +++ b/client/wizard/cobra.go @@ -0,0 +1,602 @@ +package wizard + +import ( + "bytes" + "encoding/csv" + "fmt" + "sort" + "strconv" + "strings" + "sync" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// ============ Dynamic Providers ============ + +// OptionProvider returns dynamic options for a flag's select menu +type OptionProvider func() []string + +// DefaultProvider returns a dynamic default value for a flag +type DefaultProvider func() string + +var ( + optionProviders = make(map[string]OptionProvider) + optionProvidersMu sync.RWMutex + + defaultProviders = make(map[string]DefaultProvider) + defaultProvidersMu sync.RWMutex +) + +// RegisterProvider registers a dynamic option provider for a flag name +func RegisterProvider(flagName string, fn OptionProvider) { + optionProvidersMu.Lock() + defer optionProvidersMu.Unlock() + optionProviders[flagName] = fn +} + +// RegisterDefaultProvider registers a default value provider for a flag name +func RegisterDefaultProvider(flagName string, fn DefaultProvider) { + defaultProvidersMu.Lock() + defer defaultProvidersMu.Unlock() + defaultProviders[flagName] = fn +} + +func getOptionProvider(flagName string) (OptionProvider, bool) { + optionProvidersMu.RLock() + defer optionProvidersMu.RUnlock() + fn, ok := optionProviders[flagName] + return fn, ok +} + +func getDefaultProvider(flagName string) (DefaultProvider, bool) { + defaultProvidersMu.RLock() + defer defaultProvidersMu.RUnlock() + fn, ok := defaultProviders[flagName] + return fn, ok +} + +// RunWizard runs an interactive wizard for the given command's flags. +// It returns the collected values as a map, or an error if cancelled. +func RunWizard(cmd *cobra.Command) (map[string]any, error) { + result := make(map[string]any) + groups := buildFormGroups(cmd, result) + + if len(groups) == 0 { + return result, nil + } + + form := NewGroupedWizardForm(groups) + if err := form.Run(); err != nil { + return nil, err + } + + // Finalize number fields (convert string -> int) + finalizeResult(result, cmd) + + // Apply result back to flags + if err := ApplyResultToFlags(cmd, result); err != nil { + return nil, err + } + + return result, nil +} + +// buildFormGroups creates FormGroups from command flags +func buildFormGroups(cmd *cobra.Command, result map[string]any) []*FormGroup { + // Collect flags by group + groups := make(map[string][]*pflag.Flag) + var ungrouped []*pflag.Flag + groupOrder := make([]string, 0) + seen := make(map[string]bool) + + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + if skipFlag(flag) { + return + } + if g := getFlagGroup(flag); g != "" { + groups[g] = append(groups[g], flag) + if !seen[g] { + groupOrder = append(groupOrder, g) + seen[g] = true + } + } else { + ungrouped = append(ungrouped, flag) + } + }) + + // Sort groups by order annotation + sort.SliceStable(groupOrder, func(i, j int) bool { + return getGroupOrder(groups[groupOrder[i]]) < getGroupOrder(groups[groupOrder[j]]) + }) + + var formGroups []*FormGroup + + // Add ungrouped flags as "General" group + if len(ungrouped) > 0 { + sortByOrder(ungrouped) + formGroups = append(formGroups, &FormGroup{ + Name: "general", + Title: "General", + Fields: flagsToFields(ungrouped, result), + }) + } + + // Add grouped flags + for _, name := range groupOrder { + flags := groups[name] + sortByOrder(flags) + formGroups = append(formGroups, &FormGroup{ + Name: sanitize(name), + Title: name, + Fields: flagsToFields(flags, result), + }) + } + + return formGroups +} + +// flagsToFields converts a slice of flags to FormFields +func flagsToFields(flags []*pflag.Flag, result map[string]any) []*FormField { + fields := make([]*FormField, 0, len(flags)) + for _, flag := range flags { + fields = append(fields, flagToField(flag, result)) + } + return fields +} + +// flagToField converts a single flag to a FormField +func flagToField(flag *pflag.Flag, result map[string]any) *FormField { + field := &FormField{ + Name: flag.Name, + Title: flag.Name, + Description: flag.Usage, + Required: isRequired(flag), + } + + // Get current/default value + val := flag.Value.String() + if v, ok := getDefaultFromAnnotation(flag); ok { + val = v + } + + // Determine field type + switch flag.Value.Type() { + case "bool": + field.Kind = KindConfirm + field.ConfirmVal = val == "true" + result[flag.Name] = &field.ConfirmVal + field.Value = &field.ConfirmVal + + case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": + field.Kind = KindNumber + field.InputValue = val + field.Validate = intValidator(flag.Value.Type()) + result[flag.Name] = &field.InputValue + field.Value = &field.InputValue + + case "float32", "float64": + field.Kind = KindInput + field.InputValue = val + field.Validate = floatValidator(flag) + result[flag.Name] = &field.InputValue + field.Value = &field.InputValue + + default: + // Check for slice type + if sv, ok := flag.Value.(pflag.SliceValue); ok { + field.Kind = KindInput + field.InputValue = formatCSV(sv.GetSlice()) + field.Description = flag.Usage + " (comma-separated)" + result[flag.Name] = &field.InputValue + field.Value = &field.InputValue + break + } + + field.Kind = KindInput + field.InputValue = val + result[flag.Name] = &field.InputValue + field.Value = &field.InputValue + + // Check for textarea widget + if getWidget(flag) == "textarea" { + // Still KindInput, just noted + } + } + + // Check for enum options -> convert to Select + if opts := getOptions(flag); len(opts) > 0 { + field.Kind = KindSelect + field.Options = opts + + // Find selected index + selected := 0 + found := false + for i, opt := range opts { + if opt == val { + selected = i + found = true + break + } + } + // Preserve empty defaults if the options include an empty placeholder. + if !found && (val == "" || val == "(empty)") { + for i, opt := range opts { + if opt == "" || opt == "(empty)" { + selected = i + found = true + break + } + } + } + // If empty and no empty option exists, select first non-empty. + if !found && (val == "" || val == "(empty)") { + for i, opt := range opts { + if opt != "" && opt != "(empty)" { + selected = i + break + } + } + } + field.Selected = selected + + // Store as string pointer + strVal := opts[selected] + result[flag.Name] = &strVal + field.Value = &strVal + } + + return field +} + +// ApplyResultToFlags applies wizard results back to command flags +func ApplyResultToFlags(cmd *cobra.Command, result map[string]any) error { + for name, value := range result { + flag := cmd.Flags().Lookup(name) + if flag == nil { + flag = cmd.PersistentFlags().Lookup(name) + } + if flag == nil { + continue + } + + strVal := toString(value) + currentVal := flag.Value.String() + + // Handle slice flags specially + if sv, ok := flag.Value.(pflag.SliceValue); ok { + desired, err := parseCSV(strVal) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", name, err) + } + if !sliceEqual(sv.GetSlice(), desired) { + if err := sv.Replace(desired); err != nil { + return fmt.Errorf("failed to set %s: %w", name, err) + } + flag.Changed = true + } + continue + } + + // Skip if value unchanged + if currentVal == strVal { + continue + } + + if err := flag.Value.Set(strVal); err != nil { + return fmt.Errorf("failed to set %s: %w", name, err) + } + flag.Changed = true + } + return nil +} + +// finalizeResult converts number string values to int +func finalizeResult(result map[string]any, cmd *cobra.Command) { + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + switch flag.Value.Type() { + case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": + if ptr, ok := result[flag.Name].(*string); ok && ptr != nil { + s := strings.TrimSpace(*ptr) + if s == "" { + return + } + if parsed, ok := parseNumber(flag.Value.Type(), s); ok { + result[flag.Name] = parsed + } + } + } + }) +} + +// ============ Helpers ============ + +var skipFlags = map[string]bool{"help": true, "wizard": true, "version": true} + +func skipFlag(flag *pflag.Flag) bool { + return skipFlags[flag.Name] || flag.Hidden +} + +func getFlagGroup(flag *pflag.Flag) string { + if flag.Annotations == nil { + return "" + } + if g, ok := flag.Annotations["ui:group"]; ok && len(g) > 0 { + return g[0] + } + if g, ok := flag.Annotations["group"]; ok && len(g) > 0 { + return g[0] + } + return "" +} + +func getGroupOrder(flags []*pflag.Flag) int { + min := 9999 + for _, f := range flags { + if o := getFlagOrder(f); o < min { + min = o + } + } + return min +} + +func getFlagOrder(flag *pflag.Flag) int { + if flag.Annotations == nil { + return 9999 + } + if o, ok := flag.Annotations["ui:order"]; ok && len(o) > 0 { + if n, err := strconv.Atoi(o[0]); err == nil { + return n + } + } + return 9999 +} + +func sortByOrder(flags []*pflag.Flag) { + sort.SliceStable(flags, func(i, j int) bool { + return getFlagOrder(flags[i]) < getFlagOrder(flags[j]) + }) +} + +func sanitize(name string) string { + s := strings.ToLower(name) + s = strings.ReplaceAll(s, " ", "_") + s = strings.ReplaceAll(s, "-", "_") + return s +} + +func isRequired(flag *pflag.Flag) bool { + if flag.Annotations == nil { + return false + } + if r, ok := flag.Annotations["ui:required"]; ok && len(r) > 0 { + return r[0] == "true" + } + if _, ok := flag.Annotations["cobra_annotation_bash_completion_one_required_flag"]; ok { + return true + } + return false +} + +func getDefaultFromAnnotation(flag *pflag.Flag) (string, bool) { + // 1. Check dynamic provider first + if provider, ok := getDefaultProvider(flag.Name); ok { + if val := provider(); val != "" { + return val, true + } + } + // 2. Check static annotation + if flag.Annotations != nil { + if d, ok := flag.Annotations["ui:default"]; ok && len(d) > 0 { + return d[0], true + } + } + return "", false +} + +func getWidget(flag *pflag.Flag) string { + if flag.Annotations == nil { + return "" + } + if w, ok := flag.Annotations["ui:widget"]; ok && len(w) > 0 { + return w[0] + } + return "" +} + +func getOptions(flag *pflag.Flag) []string { + // 1. Check dynamic provider first + if provider, ok := getOptionProvider(flag.Name); ok { + if opts := provider(); len(opts) > 0 { + return opts + } + } + // 2. Check static annotation + if flag.Annotations != nil { + if o, ok := flag.Annotations["ui:options"]; ok && len(o) > 0 { + return o + } + } + return nil +} + +func floatValidator(flag *pflag.Flag) func(string) error { + return func(s string) error { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + if _, err := strconv.ParseFloat(s, 64); err != nil { + return fmt.Errorf("invalid number") + } + return nil + } +} + +func parseNumber(typeName, s string) (any, bool) { + switch typeName { + case "int": + n, err := strconv.ParseInt(s, 10, strconv.IntSize) + if err != nil { + return nil, false + } + return int(n), true + case "int8": + n, err := strconv.ParseInt(s, 10, 8) + if err != nil { + return nil, false + } + return int8(n), true + case "int16": + n, err := strconv.ParseInt(s, 10, 16) + if err != nil { + return nil, false + } + return int16(n), true + case "int32": + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, false + } + return int32(n), true + case "int64": + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, false + } + return n, true + case "uint": + n, err := strconv.ParseUint(s, 10, strconv.IntSize) + if err != nil { + return nil, false + } + return uint(n), true + case "uint8": + n, err := strconv.ParseUint(s, 10, 8) + if err != nil { + return nil, false + } + return uint8(n), true + case "uint16": + n, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return nil, false + } + return uint16(n), true + case "uint32": + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return nil, false + } + return uint32(n), true + case "uint64": + n, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, false + } + return n, true + default: + return nil, false + } +} + +func intValidator(typeName string) func(string) error { + return func(s string) error { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + var err error + switch typeName { + case "int": + _, err = strconv.ParseInt(s, 10, strconv.IntSize) + case "int8": + _, err = strconv.ParseInt(s, 10, 8) + case "int16": + _, err = strconv.ParseInt(s, 10, 16) + case "int32": + _, err = strconv.ParseInt(s, 10, 32) + case "int64": + _, err = strconv.ParseInt(s, 10, 64) + case "uint": + _, err = strconv.ParseUint(s, 10, strconv.IntSize) + case "uint8": + _, err = strconv.ParseUint(s, 10, 8) + case "uint16": + _, err = strconv.ParseUint(s, 10, 16) + case "uint32": + _, err = strconv.ParseUint(s, 10, 32) + case "uint64": + _, err = strconv.ParseUint(s, 10, 64) + } + if err != nil { + return fmt.Errorf("please enter a valid number") + } + return nil + } +} + +func toString(v any) string { + switch val := v.(type) { + case *string: + if val == nil { + return "" + } + return *val + case *bool: + if val == nil { + return "false" + } + return strconv.FormatBool(*val) + case *int: + if val == nil { + return "0" + } + return strconv.Itoa(*val) + case int: + return strconv.Itoa(val) + case bool: + return strconv.FormatBool(val) + case string: + return val + default: + return fmt.Sprintf("%v", v) + } +} + +func formatCSV(vals []string) string { + if len(vals) == 0 { + return "" + } + b := &bytes.Buffer{} + w := csv.NewWriter(b) + _ = w.Write(vals) + w.Flush() + return strings.TrimSuffix(b.String(), "\n") +} + +func parseCSV(s string) ([]string, error) { + s = strings.TrimSpace(s) + if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") { + s = strings.TrimSpace(s[1 : len(s)-1]) + } + if s == "" { + return []string{}, nil + } + r := csv.NewReader(strings.NewReader(s)) + r.FieldsPerRecord = -1 + return r.Read() +} + +func sliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/client/wizard/grouped_form.go b/client/wizard/grouped_form.go new file mode 100644 index 00000000..a9749792 --- /dev/null +++ b/client/wizard/grouped_form.go @@ -0,0 +1,1004 @@ +package wizard + +import ( + "fmt" + "strconv" + "strings" + "sync" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/huh" + "github.com/charmbracelet/lipgloss" +) + +var ( + // lipglossInitOnce ensures we only initialize lipgloss background detection once + // to avoid OSC terminal queries that can conflict with readline input handling. + lipglossInitOnce sync.Once +) + +// FieldKind represents the type of field in the form +type FieldKind int + +const ( + KindSelect FieldKind = iota + KindMultiSelect + KindInput + KindConfirm + KindNumber +) + +// Styles - package-level style definitions to avoid recreation +var ( + styleTabActive = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("0")). + Background(lipgloss.Color("212")). + Padding(0, 1) + styleTabInactive = lipgloss.NewStyle(). + Foreground(lipgloss.Color("250")). + Padding(0, 1) + styleTabCompleted = lipgloss.NewStyle(). + Foreground(lipgloss.Color("42")). + Padding(0, 1) + styleSeparator = lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + styleError = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true) + styleHelp = lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + + styleFocusedTitle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")) + styleNormalTitle = lipgloss.NewStyle().Foreground(lipgloss.Color("250")) + styleDescription = lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Italic(true) + styleSelectedOption = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("0")). + Background(lipgloss.Color("212")). + Padding(0, 1) + styleUnselectedOption = lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Padding(0, 1) + styleFocusedUnselected = lipgloss.NewStyle().Foreground(lipgloss.Color("255")).Padding(0, 1) + styleMultiSelectChecked = lipgloss.NewStyle().Foreground(lipgloss.Color("42")).Padding(0, 1) + styleInputFocused = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")).Padding(0, 1) + styleInputBlurred = lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Padding(0, 1) +) + +// FormField represents a field that can be displayed in the form +type FormField struct { + Name string + Title string + Description string + Kind FieldKind + Options []string // For Select/MultiSelect + Selected int // For Select: current selection index + MultiSelect map[int]bool // For MultiSelect: selected indices + InputValue string // For Input/Number + ConfirmVal bool // For Confirm + Required bool + Validate func(string) error + Value interface{} // Pointer to store result +} + +// GroupedWizardForm is a wizard form with Tab navigation for groups +type GroupedWizardForm struct { + groups []*FormGroup + groupIndex int // Current group being edited + + // Current field within group + fieldIndex int + cursor int // Cursor within field options + + inputMode bool + inputBuf string + inputCurPos int + + width int + height int + theme *huh.Theme + quitting bool + aborted bool + + errMsg string +} + +// FormGroup represents a group of fields +type FormGroup struct { + Name string + Title string + Description string + Fields []*FormField + Optional bool // If true, this group can be collapsed + Expanded bool // If true and Optional, show fields; otherwise collapsed +} + +// NewGroupedWizardForm creates a new grouped wizard form +func NewGroupedWizardForm(groups []*FormGroup) *GroupedWizardForm { + return &GroupedWizardForm{ + groups: groups, + groupIndex: 0, + fieldIndex: 0, + cursor: 0, + width: 80, + theme: huh.ThemeCharm(), + } +} + +// WithTheme sets the theme +func (f *GroupedWizardForm) WithTheme(theme *huh.Theme) *GroupedWizardForm { + f.theme = theme + return f +} + +// Init implements tea.Model +func (f *GroupedWizardForm) Init() tea.Cmd { + f.initCursorForField() + return nil +} + +// Update implements tea.Model +func (f *GroupedWizardForm) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + // Handle input mode separately + if f.inputMode { + return f.handleInputMode(msg) + } + + key := msg.String() + + // Check if current group is a collapsed optional group + group := f.currentGroup() + isCollapsedOptional := group != nil && group.Optional && !group.Expanded + + // Number keys 1-9 for group navigation + if len(key) == 1 && key[0] >= '1' && key[0] <= '9' { + groupNum := int(key[0] - '1') + if groupNum < len(f.groups) { + f.errMsg = "" + f.saveCurrentField() + f.groupIndex = groupNum + f.fieldIndex = 0 + f.initCursorForField() + return f, nil + } + } + + switch key { + case "ctrl+c", "esc": + f.aborted = true + f.quitting = true + return f, tea.Quit + + case "tab": + // Next group + f.errMsg = "" + f.saveCurrentField() + f.nextGroup() + + case "shift+tab": + // Previous group + f.errMsg = "" + f.saveCurrentField() + f.prevGroup() + + case "up", "k": + if isCollapsedOptional { + break // No field navigation in collapsed group + } + f.errMsg = "" + f.saveCurrentField() + f.prevField() + + case "down", "j": + if isCollapsedOptional { + break // No field navigation in collapsed group + } + f.errMsg = "" + f.saveCurrentField() + f.nextField() + + case "left", "h": + if isCollapsedOptional { + break + } + f.errMsg = "" + f.prevOption() + + case "right", "l": + if isCollapsedOptional { + break + } + f.errMsg = "" + f.nextOption() + + case " ": + f.errMsg = "" + // Handle collapsed optional group - expand it + if isCollapsedOptional { + group.Expanded = true + f.fieldIndex = 0 + f.initCursorForField() + break + } + field := f.currentField() + if field == nil { + break + } + if field.Kind == KindMultiSelect { + f.toggleSelection() + } else if field.Kind == KindConfirm { + f.cursor = 1 - f.cursor + f.saveCurrentField() + } + + case "ctrl+d": + return f.trySubmit() + + case "enter": + // Handle collapsed optional group - expand it + if isCollapsedOptional { + f.errMsg = "" + group.Expanded = true + f.fieldIndex = 0 + f.initCursorForField() + break + } + field := f.currentField() + if field == nil { + return f.trySubmit() + } + if field.Kind == KindInput || field.Kind == KindNumber { + f.errMsg = "" + f.inputMode = true + f.inputBuf = field.InputValue + f.inputCurPos = len(f.inputBuf) + } else { + return f.trySubmit() + } + + case "c": + // Collapse current optional group if expanded + if group != nil && group.Optional && group.Expanded { + f.errMsg = "" + group.Expanded = false + f.fieldIndex = 0 + } + + case "a": + if isCollapsedOptional { + break + } + if f.currentField() != nil && f.currentField().Kind == KindMultiSelect { + f.errMsg = "" + f.selectAll() + } + + case "n": + if isCollapsedOptional { + break + } + field := f.currentField() + if field != nil { + if field.Kind == KindMultiSelect { + f.errMsg = "" + f.deselectAll() + } else if field.Kind == KindConfirm { + f.errMsg = "" + f.cursor = 1 + f.saveCurrentField() + } + } + + case "y": + if isCollapsedOptional { + break + } + if f.currentField() != nil && f.currentField().Kind == KindConfirm { + f.errMsg = "" + f.cursor = 0 + f.saveCurrentField() + } + } + + case tea.WindowSizeMsg: + f.width = msg.Width + f.height = msg.Height + } + + return f, nil +} + +// handleInputMode handles key events when in text input mode +func (f *GroupedWizardForm) handleInputMode(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "ctrl+c", "esc": + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + + case "enter": + field := f.currentField() + if field == nil { + f.inputMode = false + return f, nil + } + candidate := f.inputBuf + old := field.InputValue + field.InputValue = candidate + if err := f.validateField(field); err != nil { + field.InputValue = old + f.errMsg = err.Error() + return f, nil + } + f.saveCurrentField() + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + // Move to next field + if f.fieldIndex < len(f.currentGroup().Fields)-1 { + f.nextField() + } + + case "ctrl+d": + field := f.currentField() + if field == nil { + f.inputMode = false + return f.trySubmit() + } + candidate := f.inputBuf + old := field.InputValue + field.InputValue = candidate + if err := f.validateField(field); err != nil { + field.InputValue = old + f.errMsg = err.Error() + return f, nil + } + f.saveCurrentField() + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + return f.trySubmit() + + case "backspace": + f.errMsg = "" + if len(f.inputBuf) > 0 { + f.inputBuf = f.inputBuf[:len(f.inputBuf)-1] + } + + default: + f.errMsg = "" + if len(msg.String()) == 1 { + f.inputBuf += msg.String() + } else if msg.Type == tea.KeySpace { + f.inputBuf += " " + } + } + + return f, nil +} + +// View implements tea.Model +func (f *GroupedWizardForm) View() string { + var sb strings.Builder + + // Tab bar - show required groups first, then optional groups + var tabs []string + for i, group := range f.groups { + label := fmt.Sprintf("%d.%s", i+1, group.Title) + + // Add indicator for optional groups + if group.Optional { + if group.Expanded { + label = fmt.Sprintf("%d.▼ %s", i+1, group.Title) + } else { + label = fmt.Sprintf("%d.▶ %s", i+1, group.Title) + } + } + + switch { + case i == f.groupIndex: + tabs = append(tabs, styleTabActive.Render(label)) + case group.Optional && !group.Expanded: + // Collapsed optional groups are "skipped": show them dimmed instead of as completed. + tabs = append(tabs, styleHelp.Render(label)) + case f.isGroupComplete(i): + tabs = append(tabs, styleTabCompleted.Render("✓ "+label)) + default: + tabs = append(tabs, styleTabInactive.Render(label)) + } + } + sb.WriteString(strings.Join(tabs, " ")) + sb.WriteString("\n") + + // Separator + sb.WriteString(styleSeparator.Render(strings.Repeat("─", minInt(f.width, 70)))) + sb.WriteString("\n\n") + + // Render current group + group := f.currentGroup() + if group == nil || len(group.Fields) == 0 { + sb.WriteString("(No fields in this group)\n") + } else if group.Optional && !group.Expanded { + // Collapsed optional group - show toggle prompt + sb.WriteString(styleDescription.Render(fmt.Sprintf(" %s (Optional)", group.Title))) + sb.WriteString("\n\n") + sb.WriteString(styleHelp.Render(" Press Enter or Space to expand, or Tab to skip")) + sb.WriteString("\n") + } else { + // Show all fields in current group + for i, field := range group.Fields { + sb.WriteString(f.renderField(field, i == f.fieldIndex)) + sb.WriteString("\n") + } + } + + // Error message + if strings.TrimSpace(f.errMsg) != "" { + sb.WriteString("\n") + sb.WriteString(styleError.Render("Error: " + f.errMsg)) + } + + // Help text + sb.WriteString("\n") + sb.WriteString(f.renderHelp()) + + return sb.String() +} + +// renderField renders a single field with all its options visible +func (f *GroupedWizardForm) renderField(field *FormField, isFocused bool) string { + var sb strings.Builder + + // Title with focus indicator + if isFocused { + sb.WriteString(styleFocusedTitle.Render("> " + field.Title)) + } else { + sb.WriteString(styleNormalTitle.Render(" " + field.Title)) + } + + // Description on same line if short + if field.Description != "" && len(field.Description) < 40 { + sb.WriteString(styleDescription.Render(" " + field.Description)) + } + sb.WriteString("\n") + + // Render options based on field kind + sb.WriteString(" ") + switch field.Kind { + case KindSelect: + sb.WriteString(f.renderSelectOptions(field, isFocused)) + case KindMultiSelect: + sb.WriteString(f.renderMultiSelectOptions(field, isFocused)) + case KindConfirm: + sb.WriteString(f.renderConfirmOptions(field, isFocused)) + case KindInput, KindNumber: + sb.WriteString(f.renderInputField(field, isFocused)) + } + + return sb.String() +} + +// selectOptionStyle returns the appropriate style based on focus and selection state +func selectOptionStyle(isFocused, isSelected bool) lipgloss.Style { + if isSelected { + return styleSelectedOption + } + if isFocused { + return styleFocusedUnselected + } + return styleUnselectedOption +} + +func (f *GroupedWizardForm) renderSelectOptions(field *FormField, isFocused bool) string { + parts := make([]string, 0, len(field.Options)) + for i, opt := range field.Options { + display := opt + if display == "" { + display = "(empty)" + } + style := selectOptionStyle(isFocused, i == field.Selected) + parts = append(parts, style.Render(display)) + } + return strings.Join(parts, " ") +} + +func (f *GroupedWizardForm) renderMultiSelectOptions(field *FormField, isFocused bool) string { + parts := make([]string, 0, len(field.Options)) + for i, opt := range field.Options { + marker := "○" + if field.MultiSelect[i] { + marker = "●" + } + display := fmt.Sprintf("%s %s", marker, opt) + isCursor := isFocused && i == f.cursor + + var style lipgloss.Style + switch { + case isCursor: + style = styleSelectedOption + case field.MultiSelect[i]: + style = styleMultiSelectChecked + case isFocused: + style = styleFocusedUnselected + default: + style = styleUnselectedOption + } + parts = append(parts, style.Render(display)) + } + return strings.Join(parts, " ") +} + +func (f *GroupedWizardForm) renderConfirmOptions(field *FormField, isFocused bool) string { + yesStyle := selectOptionStyle(isFocused, field.ConfirmVal) + noStyle := selectOptionStyle(isFocused, !field.ConfirmVal) + return yesStyle.Render("Yes") + " " + noStyle.Render("No") +} + +func (f *GroupedWizardForm) renderInputField(field *FormField, isFocused bool) string { + if isFocused && f.inputMode { + return styleInputFocused.Render("[" + f.inputBuf + "█]") + } + display := field.InputValue + if display == "" { + display = "(empty)" + } + if isFocused { + return styleInputFocused.Render("["+display+"]") + styleDescription.Render(" Enter to edit") + } + return styleInputBlurred.Render("[" + display + "]") +} + +func (f *GroupedWizardForm) renderHelp() string { + group := f.currentGroup() + + // Check if current group is a collapsed optional group + if group != nil && group.Optional && !group.Expanded { + return styleHelp.Render("Enter/Space: expand Tab: skip group 1-9: jump Ctrl+D: submit") + } + + // Check if current group is an expanded optional group + if group != nil && group.Optional && group.Expanded { + field := f.currentField() + baseHelp := "↑/↓: field c: collapse Tab: group " + if field == nil { + return styleHelp.Render(baseHelp + "Ctrl+D: submit") + } + switch field.Kind { + case KindMultiSelect: + return styleHelp.Render(baseHelp + "Space: toggle a: all Ctrl+D: submit") + case KindConfirm: + return styleHelp.Render(baseHelp + "←/→: toggle Ctrl+D: submit") + case KindInput, KindNumber: + if f.inputMode { + return styleHelp.Render("Enter: save Esc: cancel Ctrl+D: save & submit") + } + return styleHelp.Render(baseHelp + "Enter: edit Ctrl+D: submit") + default: + return styleHelp.Render(baseHelp + "←/→: select Ctrl+D: submit") + } + } + + field := f.currentField() + if field == nil { + return styleHelp.Render("Tab: next group Shift+Tab: prev group 1-9: jump to group Ctrl+D: submit") + } + + baseHelp := "↑/↓: field Tab/Shift+Tab: group 1-9: jump " + + switch field.Kind { + case KindMultiSelect: + return styleHelp.Render(baseHelp + "←/→: move Space: toggle a: all n: none Ctrl+D: submit") + case KindConfirm: + return styleHelp.Render(baseHelp + "←/→: toggle y: Yes n: No Ctrl+D: submit") + case KindInput, KindNumber: + if f.inputMode { + return styleHelp.Render("Enter: save Esc: cancel Ctrl+D: save & submit") + } + return styleHelp.Render(baseHelp + "Enter: edit Ctrl+D: submit") + default: + return styleHelp.Render(baseHelp + "←/→: select Ctrl+D: submit") + } +} + +// Helper methods + +func (f *GroupedWizardForm) currentGroup() *FormGroup { + if f.groupIndex >= 0 && f.groupIndex < len(f.groups) { + return f.groups[f.groupIndex] + } + return nil +} + +func (f *GroupedWizardForm) currentField() *FormField { + group := f.currentGroup() + if group == nil { + return nil + } + if f.fieldIndex >= 0 && f.fieldIndex < len(group.Fields) { + return group.Fields[f.fieldIndex] + } + return nil +} + +func (f *GroupedWizardForm) nextGroup() { + f.groupIndex++ + if f.groupIndex >= len(f.groups) { + f.groupIndex = 0 + } + f.fieldIndex = 0 + f.initCursorForField() +} + +func (f *GroupedWizardForm) prevGroup() { + f.groupIndex-- + if f.groupIndex < 0 { + f.groupIndex = len(f.groups) - 1 + } + f.fieldIndex = 0 + f.initCursorForField() +} + +func (f *GroupedWizardForm) nextField() { + group := f.currentGroup() + if group == nil { + return + } + f.fieldIndex++ + if f.fieldIndex >= len(group.Fields) { + f.fieldIndex = 0 + } + f.initCursorForField() +} + +func (f *GroupedWizardForm) prevField() { + group := f.currentGroup() + if group == nil { + return + } + f.fieldIndex-- + if f.fieldIndex < 0 { + f.fieldIndex = len(group.Fields) - 1 + } + f.initCursorForField() +} + +func (f *GroupedWizardForm) initCursorForField() { + field := f.currentField() + if field == nil { + f.cursor = 0 + return + } + switch field.Kind { + case KindSelect: + f.cursor = field.Selected + case KindConfirm: + if field.ConfirmVal { + f.cursor = 0 + } else { + f.cursor = 1 + } + default: + f.cursor = 0 + } +} + +// wrapIndex wraps index in range [0, max) with cycling +func wrapIndex(index, delta, max int) int { + if max <= 0 { + return 0 + } + return (index + delta + max) % max +} + +func (f *GroupedWizardForm) nextOption() { + field := f.currentField() + if field == nil { + return + } + switch field.Kind { + case KindSelect: + f.cursor = wrapIndex(f.cursor, 1, len(field.Options)) + field.Selected = f.cursor + f.saveCurrentField() + case KindMultiSelect: + f.cursor = wrapIndex(f.cursor, 1, len(field.Options)) + case KindConfirm: + f.cursor = 1 - f.cursor + f.saveCurrentField() + case KindInput, KindNumber: + if !f.inputMode { + f.saveCurrentField() + f.nextField() + } + } +} + +func (f *GroupedWizardForm) prevOption() { + field := f.currentField() + if field == nil { + return + } + switch field.Kind { + case KindSelect: + f.cursor = wrapIndex(f.cursor, -1, len(field.Options)) + field.Selected = f.cursor + f.saveCurrentField() + case KindMultiSelect: + f.cursor = wrapIndex(f.cursor, -1, len(field.Options)) + case KindConfirm: + f.cursor = 1 - f.cursor + f.saveCurrentField() + case KindInput, KindNumber: + if !f.inputMode { + f.saveCurrentField() + f.prevField() + } + } +} + +func (f *GroupedWizardForm) ensureMultiSelect(field *FormField) { + if field.MultiSelect == nil { + field.MultiSelect = make(map[int]bool) + } +} + +func (f *GroupedWizardForm) toggleSelection() { + field := f.currentField() + if field == nil { + return + } + f.ensureMultiSelect(field) + field.MultiSelect[f.cursor] = !field.MultiSelect[f.cursor] + f.saveCurrentField() +} + +func (f *GroupedWizardForm) selectAll() { + field := f.currentField() + if field == nil { + return + } + f.ensureMultiSelect(field) + for i := range field.Options { + field.MultiSelect[i] = true + } + f.saveCurrentField() +} + +func (f *GroupedWizardForm) deselectAll() { + field := f.currentField() + if field == nil { + return + } + field.MultiSelect = make(map[int]bool) + f.saveCurrentField() +} + +func (f *GroupedWizardForm) saveCurrentField() { + field := f.currentField() + if field == nil { + return + } + + switch field.Kind { + case KindSelect: + if ptr, ok := field.Value.(*string); ok && ptr != nil { + if field.Selected >= 0 && field.Selected < len(field.Options) { + *ptr = field.Options[field.Selected] + } + } + case KindMultiSelect: + if ptr, ok := field.Value.(*[]string); ok && ptr != nil { + var selected []string + for i, opt := range field.Options { + if field.MultiSelect[i] { + selected = append(selected, opt) + } + } + *ptr = selected + } + case KindConfirm: + field.ConfirmVal = (f.cursor == 0) + if ptr, ok := field.Value.(*bool); ok && ptr != nil { + *ptr = field.ConfirmVal + } + case KindInput, KindNumber: + if ptr, ok := field.Value.(*string); ok && ptr != nil { + *ptr = field.InputValue + } + } +} + +func (f *GroupedWizardForm) isGroupComplete(groupIdx int) bool { + if groupIdx < 0 || groupIdx >= len(f.groups) { + return false + } + group := f.groups[groupIdx] + + // Collapsed optional groups are considered "complete" (skipped) + if group.Optional && !group.Expanded { + return true + } + + for _, field := range group.Fields { + if err := f.validateField(field); err != nil { + return false + } + } + return true +} + +func (f *GroupedWizardForm) trySubmit() (tea.Model, tea.Cmd) { + f.saveCurrentField() + if err := f.validateAllFields(); err != nil { + return f, nil + } + f.quitting = true + return f, tea.Quit +} + +func (f *GroupedWizardForm) validateAllFields() error { + for gi, group := range f.groups { + // Skip collapsed optional groups (user chose to skip) + if group.Optional && !group.Expanded { + continue + } + + for fi, field := range group.Fields { + if err := f.validateField(field); err != nil { + f.errMsg = err.Error() + f.inputMode = false + f.inputBuf = "" + f.groupIndex = gi + f.fieldIndex = fi + f.initCursorForField() + return err + } + } + } + f.errMsg = "" + return nil +} + +// validateStringField validates string-like fields (Select, Input) +func (f *GroupedWizardForm) validateStringField(value string, field *FormField, label string) error { + if !field.Required && field.Validate == nil { + return nil + } + var required func(string) error + if field.Required { + required = requiredStringValidator(label) + } + return chainStringValidators(required, field.Validate)(value) +} + +// requiredError returns a formatted required error message +func requiredError(label string) error { + if label != "" { + return fmt.Errorf("%s is required", label) + } + return fmt.Errorf("value is required") +} + +func (f *GroupedWizardForm) validateField(field *FormField) error { + if field == nil { + return nil + } + + label := field.Title + if strings.TrimSpace(label) == "" { + label = field.Name + } + + switch field.Kind { + case KindSelect: + val := "" + if field.Selected >= 0 && field.Selected < len(field.Options) { + val = field.Options[field.Selected] + } + return f.validateStringField(val, field, label) + + case KindMultiSelect: + if !field.Required { + return nil + } + for _, selected := range field.MultiSelect { + if selected { + return nil + } + } + return requiredError(label) + + case KindInput: + return f.validateStringField(field.InputValue, field, label) + + case KindNumber: + s := strings.TrimSpace(field.InputValue) + if s == "" { + if field.Required { + return requiredError(label) + } + return nil + } + if field.Validate != nil { + if err := field.Validate(s); err != nil { + return err + } + return nil + } + if _, err := strconv.Atoi(s); err != nil { + return fmt.Errorf("please enter a valid number") + } + return nil + + case KindConfirm: + return nil + default: + return nil + } +} + +// Run executes the grouped form +func (f *GroupedWizardForm) Run() error { + // Prevent lipgloss from sending OSC terminal queries (like \x1b]11;?) + // which can conflict with readline's input handling and cause garbled output. + // We set HasDarkBackground once at startup to avoid runtime OSC queries. + lipglossInitOnce.Do(func() { + lipgloss.SetHasDarkBackground(true) + }) + + p := tea.NewProgram(f) + _, err := p.Run() + if err != nil { + return err + } + if f.aborted { + return fmt.Errorf("wizard aborted") + } + // Final save of all fields + for gi := range f.groups { + for fi := range f.groups[gi].Fields { + f.groupIndex = gi + f.fieldIndex = fi + f.initCursorForField() + f.saveCurrentField() + } + } + return nil +} + +// Aborted returns true if the user cancelled +func (f *GroupedWizardForm) Aborted() bool { + return f.aborted +} + +// requiredStringValidator creates a validator that checks for non-empty strings +func requiredStringValidator(label string) func(string) error { + return func(s string) error { + if strings.TrimSpace(s) == "" { + if label != "" { + return fmt.Errorf("%s is required", label) + } + return fmt.Errorf("value is required") + } + return nil + } +} + +// chainStringValidators chains multiple string validators together +func chainStringValidators(validators ...func(string) error) func(string) error { + return func(s string) error { + for _, v := range validators { + if v == nil { + continue + } + if err := v(s); err != nil { + return err + } + } + return nil + } +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/go.mod b/go.mod index 924b321a..4b95d8b3 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,8 @@ require ( github.com/chainreactors/utils v0.0.0-20241209140746-65867d2f78b2 github.com/charmbracelet/bubbletea v1.3.4 github.com/charmbracelet/glamour v0.8.0 + github.com/charmbracelet/huh v0.0.0-00010101000000-000000000000 + github.com/charmbracelet/lipgloss v1.1.0 github.com/corpix/uarand v0.2.0 github.com/dustin/go-humanize v1.0.1 github.com/evertras/bubble-table v0.17.2 @@ -73,6 +75,7 @@ require ( github.com/aymerick/douceur v0.2.0 // indirect github.com/blinkbean/dingtalk v1.1.3 // indirect github.com/carapace-sh/carapace-shlex v1.0.1 // indirect + github.com/catppuccin/go v0.2.0 // indirect github.com/cbroglie/mustache v1.4.0 // indirect github.com/chainreactors/fingers v0.0.0-20240702104653-a66e34aa41df // indirect github.com/chainreactors/go-metrics v0.0.0-20220926021830-24787b7a10f8 // indirect @@ -80,9 +83,9 @@ require ( github.com/charmbracelet/bubbles v0.20.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect github.com/charmbracelet/harmonica v0.2.0 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/x/ansi v0.8.0 // indirect github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cjoudrey/gluahttp v0.0.0-20201111170219-25003d9adfa9 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -121,6 +124,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/miekg/dns v1.1.67 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/montanaflynn/stats v0.7.1 // indirect @@ -165,10 +169,15 @@ require ( ) replace ( + dario.cat/mergo => github.com/imdario/mergo v0.3.16 + github.com/charmbracelet/huh => github.com/charmbracelet/huh v0.6.0 + github.com/miekg/dns => github.com/miekg/dns v1.1.58 golang.org/x/crypto => golang.org/x/crypto v0.24.0 + golang.org/x/mod => golang.org/x/mod v0.17.0 golang.org/x/net => golang.org/x/net v0.23.0 golang.org/x/sync => golang.org/x/sync v0.11.0 golang.org/x/sys => golang.org/x/sys v0.30.0 + golang.org/x/tools => golang.org/x/tools v0.21.0 ) replace ( diff --git a/helper/cryptography/cryptography.go b/helper/cryptography/cryptography.go index 82609eed..1c4b3bb4 100644 --- a/helper/cryptography/cryptography.go +++ b/helper/cryptography/cryptography.go @@ -41,6 +41,13 @@ var ( // and we can ensure there will only ever be a single recipient, // we can just ignore add/remove it at runtime to safe space. agePrefix = []byte("age-encryption.org/v1\n-> X25519 ") + ageHeader = []byte("age-encryption.org/v1\n") + + keyExchangeReplay = &sync.Map{} + serverKeyPairMu sync.Mutex + serverKeyPair *AgeKeyPair + minisignKeyMu sync.Mutex + minisignKey *minisign.PrivateKey ) // deriveKeyFrom - Derives a key from input data using SHA256 @@ -121,7 +128,7 @@ func AgeEncrypt(recipientPublicKey string, plaintext []byte) ([]byte, error) { if err := stream.Close(); err != nil { return nil, err } - return buf.Bytes(), nil + return bytes.TrimPrefix(buf.Bytes(), agePrefix), nil } // AgeDecrypt - Decrypt using Curve 25519 + ChaCha20Poly1305 @@ -132,8 +139,12 @@ func AgeDecrypt(recipientPrivateKey string, ciphertext []byte) ([]byte, error) { return nil, err } - // 直接使用 ciphertext,Age 库会自动处理 grease recipients - buf := bytes.NewBuffer(ciphertext) + // Accept both trimmed payloads and full age headers. + payload := ciphertext + if !bytes.HasPrefix(ciphertext, ageHeader) { + payload = append(agePrefix, ciphertext...) + } + buf := bytes.NewBuffer(payload) stream, err := age.Decrypt(buf, identity) if err != nil { // 如果解密失败,尝试添加调试信息 @@ -150,12 +161,6 @@ func AgeDecrypt(recipientPrivateKey string, ciphertext []byte) ([]byte, error) { // AgeKeyPairFromImplant - Decrypt the session key from an implant func AgeKeyExFromImplant(serverPrivateKey string, implantPrivateKey string, ciphertext []byte) ([]byte, error) { - // TODO - Store the hash of the implant's key exchange to prevent replay attacks - // Check for replay attacks - //if err := db.CheckKeyExReplay(ciphertext); err != nil { - // return nil, ErrDecryptFailed - //} - // Decrypt the message plaintext, err := AgeDecrypt(serverPrivateKey, ciphertext) if err != nil { @@ -176,6 +181,9 @@ func AgeKeyExFromImplant(serverPrivateKey string, implantPrivateKey string, ciph if !hmac.Equal(mac.Sum(nil), plaintext[:sha256Size]) { return nil, ErrDecryptFailed } + if err := recordKeyExchange(ciphertext); err != nil { + return nil, err + } return plaintext[sha256Size:], nil } @@ -272,23 +280,28 @@ func serverSignRawBuf(buf []byte) []byte { // AgeServerKeyPair - Get teh server's ECC key pair func AgeServerKeyPair() *AgeKeyPair { + serverKeyPairMu.Lock() + defer serverKeyPairMu.Unlock() + if serverKeyPair != nil { + return serverKeyPair + } // TODO - get key value from db //data, err := db.GetKeyValue(serverAgeKeyPairKey) // test data, err := json.Marshal(&AgeKeyPair{}) - //if err == db.ErrRecordNotFound { - // keyPair, err := generateServerKeyPair() - // if err != nil { - // panic(err) - // } - // return keyPair - //} keyPair := &AgeKeyPair{} - err = json.Unmarshal([]byte(data), keyPair) + if err == nil { + if err := json.Unmarshal([]byte(data), keyPair); err == nil && keyPair.Public != "" && keyPair.Private != "" { + serverKeyPair = keyPair + return serverKeyPair + } + } + keyPair, err = generateServerKeyPair() if err != nil { panic(err) } - return keyPair + serverKeyPair = keyPair + return serverKeyPair } func generateServerKeyPair() (*AgeKeyPair, error) { @@ -331,28 +344,45 @@ func MinisignServerSign(message []byte) string { // MinisignServerPrivateKey - Get the server's minisign key pair func MinisignServerPrivateKey() *minisign.PrivateKey { + minisignKeyMu.Lock() + defer minisignKeyMu.Unlock() + if minisignKey != nil { + return minisignKey + } // TODO - get key value from db - // test - data, err := json.Marshal(&AgeKeyPair{}) //data, err := db.GetKeyValue(serverMinisignPrivateKey) - //if err == db.ErrRecordNotFound { - // privateKey, err := generateServerMinisignPrivateKey() - // if err != nil { - // panic(err) - // } - // return privateKey - //} + // test + data, err := json.Marshal(&minisignPrivateKey{}) privateKey := &minisignPrivateKey{} - err = json.Unmarshal([]byte(data), privateKey) + if err == nil { + if err := json.Unmarshal([]byte(data), privateKey); err == nil && len(privateKey.PrivateKey) == ed25519.PrivateKeySize { + rawBytes := [ed25519.PrivateKeySize]byte{} + copy(rawBytes[:], privateKey.PrivateKey) + minisignKey = &minisign.PrivateKey{ + RawID: privateKey.ID, + RawBytes: rawBytes, + } + return minisignKey + } + } + privateKeyValue, err := generateServerMinisignPrivateKey() if err != nil { panic(err) } - rawBytes := [ed25519.PrivateKeySize]byte{} - copy(rawBytes[:], privateKey.PrivateKey) - return &minisign.PrivateKey{ - RawID: privateKey.ID, - RawBytes: rawBytes, + minisignKey = privateKeyValue + return minisignKey +} + +func recordKeyExchange(ciphertext []byte) error { + if len(ciphertext) == 0 { + return ErrDecryptFailed + } + digest := sha256.Sum256(ciphertext) + key := base64.RawStdEncoding.EncodeToString(digest[:]) + if _, ok := keyExchangeReplay.LoadOrStore(key, true); ok { + return ErrReplayAttack } + return nil } func generateServerMinisignPrivateKey() (*minisign.PrivateKey, error) { diff --git a/helper/cryptography/cryptography_test.go b/helper/cryptography/cryptography_test.go index 7cbcabde..53194341 100644 --- a/helper/cryptography/cryptography_test.go +++ b/helper/cryptography/cryptography_test.go @@ -3,7 +3,6 @@ package cryptography import ( "bytes" "crypto/rand" - "fmt" insecureRand "math/rand" "os" "sync" @@ -71,18 +70,20 @@ func TestAgeEncrypt(t *testing.T) { if err != nil { t.Fatal(err) } - fmt.Println(encrypted) - if !bytes.Equal([]byte(data), encrypted) { - t.Fatalf("Sample does not match decrypted data") + if bytes.Equal([]byte(data), encrypted) { + t.Fatalf("Ciphertext should not match plaintext") } } func TestAgeDecrypt(t *testing.T) { data := []byte{97, 103, 101, 45, 101, 110, 99, 114, 121, 112, 116, 105, 111, 110, 46, 111, 114, 103, 47, 118, 49, 10, 45, 62, 32, 88, 50, 53, 53, 49, 57, 32, 112, 115, 88, 48, 103, 104, 65, 84, 68, 120, 77, 111, 97, 84, 87, 77, 48, 47, 83, 117, 119, 50, 80, 107, 114, 52, 66, 43, 88, 105, 89, 75, 54, 112, 81, 122, 112, 43, 86, 104, 116, 103, 85, 10, 51, 82, 89, 50, 54, 116, 119, 70, 111, 108, 101, 70, 121, 66, 110, 57, 66, 101, 47, 121, 69, 79, 102, 99, 119, 76, 56, 107, 111, 115, 57, 55, 52, 115, 117, 110, 52, 56, 108, 48, 119, 69, 69, 10, 45, 62, 32, 66, 45, 103, 114, 101, 97, 115, 101, 32, 116, 61, 63, 42, 123, 75, 42, 32, 44, 47, 10, 66, 103, 111, 85, 119, 76, 83, 69, 120, 74, 120, 74, 87, 85, 109, 71, 118, 53, 73, 51, 120, 70, 121, 76, 43, 113, 52, 57, 97, 117, 50, 86, 74, 118, 108, 89, 47, 75, 110, 98, 66, 65, 49, 108, 72, 56, 48, 48, 52, 112, 98, 89, 47, 71, 69, 99, 89, 53, 52, 10, 45, 45, 45, 32, 56, 106, 115, 65, 101, 57, 97, 69, 108, 110, 116, 50, 109, 67, 99, 103, 122, 82, 48, 113, 53, 116, 55, 118, 57, 90, 86, 98, 112, 90, 85, 85, 83, 77, 71, 55, 89, 50, 79, 86, 104, 88, 81, 10, 226, 0, 72, 213, 103, 70, 169, 21, 148, 223, 128, 36, 70, 193, 95, 18, 97, 75, 179, 247, 222, 134, 200, 37, 24, 71, 167, 217, 5, 2, 143, 49, 50, 111, 245, 43, 73, 220, 140, 30, 133, 253, 34, 169, 28, 42, 179, 41, 170, 121, 110, 133, 51, 13, 184, 144, 192, 157, 152, 232, 20, 247, 130, 113, 201, 129, 233, 236, 222, 218, 132, 55, 199, 115, 246, 2, 208, 37, 248, 92, 110, 250, 188, 82, 162, 169, 104, 254, 34, 150, 212, 237, 208, 206, 202, 69, 32, 21, 74, 112, 195, 59, 0, 161, 192, 219, 139, 233, 197, 157, 177, 174, 7, 84, 168, 28, 125, 18, 148, 94, 225, 173, 98, 197, 239, 250, 240, 252, 1, 139, 146, 64, 22, 247, 199, 12, 237, 63, 195, 64, 157, 168, 82, 35, 64, 253, 114, 176, 11, 216, 112, 187, 212, 217, 28, 249, 67, 33, 131, 22, 87, 246, 79, 52, 91, 107, 143, 210, 77, 150, 104, 48, 7, 86, 165, 103, 13, 188, 228, 193, 194, 246, 184, 85, 121, 73, 54, 177, 66, 145, 103, 47, 96, 134, 133, 85, 187, 66, 123, 141, 198, 182, 49, 195, 73, 71, 29, 152, 166, 176, 69, 124, 177, 249, 0, 242, 169, 169, 151, 64, 188, 45, 45, 109, 252, 215, 94, 188, 112, 245, 5, 182, 50, 42, 203, 55, 133, 166, 160, 209, 159, 127, 167, 132, 222, 84, 108, 108, 19, 237, 154, 20, 109, 118, 175, 120, 75, 216, 206, 41, 246, 68, 110, 190, 132, 138, 151, 202, 203, 118, 232, 245, 158, 57, 159, 191, 188, 94, 173, 76, 214, 55, 75, 62, 94, 66, 185, 3, 42, 193, 217, 142, 136, 219, 175, 116, 107, 148, 157, 165, 210, 216, 71, 206, 237, 83, 106, 236, 52, 216, 124, 216, 13, 168, 53, 137, 180, 197, 156, 55, 156, 185, 70, 189, 47, 71, 160, 204, 158, 49, 16, 238, 127, 191, 31, 252, 229, 210, 227, 7, 151, 157, 146, 168, 115, 56, 223, 6, 253, 44, 170, 49, 236, 217, 55, 187, 248, 224, 222, 162, 181, 46, 225, 189, 197, 98, 251, 135, 185, 180, 138, 71, 218, 247, 96, 71, 91, 158, 186, 158, 86, 229, 226, 82, 3, 5, 237, 177, 176, 132, 17, 97, 227, 49, 217, 7, 195, 149, 130, 114, 36, 76, 64, 134, 254, 21, 116, 249, 103, 250, 111, 154, 249, 176, 209, 62, 65, 254, 216, 50, 113, 61, 53, 43, 36, 224, 244, 101, 181, 186, 198, 27, 74, 63, 146, 119, 108, 98, 236, 16, 156, 44, 60, 132, 173, 82, 31, 205, 167, 186, 249, 2, 123, 68, 86, 94, 80, 112, 165, 116, 76, 87, 25, 116, 2, 250, 212, 231, 254, 14, 130, 18, 175, 10, 198, 204, 178, 73, 68, 214, 6, 30, 16, 251, 243, 199, 47, 125, 212, 110, 36, 80, 5, 42, 253, 33, 27, 179, 50, 53, 130, 152, 75, 0, 79, 84, 160, 179, 238, 179, 203, 248, 183, 103, 83, 53, 18, 181, 80, 120, 171, 110, 142, 68, 58, 52, 220, 163, 44, 205, 124, 215, 86, 101, 6, 83, 177, 250, 183, 115, 213, 236, 226, 185, 143, 251, 73, 71, 117, 34, 57, 122, 236, 150, 230, 40, 219, 122, 237, 35, 116, 7, 88, 190, 205, 124, 42, 147, 135, 252, 194, 156, 188, 228, 102, 238, 162, 127, 12, 204, 8, 56, 119, 201, 158, 225, 15, 140, 149, 187, 207, 64, 210, 35, 96, 18, 165, 22, 54, 170, 199, 51, 49, 154, 215, 220, 3, 153, 109, 91, 145, 237, 136, 74, 12, 207, 195, 25, 152, 108, 175, 9, 185, 194, 50, 117, 31, 181, 79, 77, 45, 147, 39, 80, 49, 80, 153, 118, 42, 199, 74, 207, 111, 0, 107, 14, 12, 171, 240, 186, 52, 73, 25, 133, 5, 91, 165, 44, 207, 37, 142, 177, 104, 23, 71, 234, 80, 110, 254, 110, 199, 162, 204, 194, 193, 28, 149, 222, 47, 26, 204, 186, 192, 23, 204, 166, 194, 14, 58, 20, 102, 233, 123, 128, 205, 122, 206, 25, 96, 254, 101, 55, 83, 113, 117, 77, 207, 34, 166, 231, 253, 191, 218, 177, 24, 227, 92, 9, 166, 228, 217, 238, 7, 66, 65, 218, 202, 91, 225, 203, 183, 29, 87, 168, 76, 255, 186, 204, 199, 245, 85, 90, 149, 38, 208, 70, 31, 28, 202, 92, 7, 106, 158, 50, 186, 23, 179, 29, 85, 234, 104, 245, 21, 186, 167, 37, 50, 10, 184, 119, 246, 96, 62, 201, 43, 125, 128, 239, 79, 163, 5, 116, 45, 149, 27, 147, 181, 121, 243, 143, 31, 193, 21, 91, 5, 107, 179, 114, 159, 161, 66, 47, 52, 24, 103, 249, 242, 140, 12, 17, 96, 8, 116, 222, 56, 117, 126, 83, 184, 22, 186, 190, 175, 226, 160, 97, 18, 222, 193, 84, 245, 29, 195, 81, 228, 140, 223, 123, 218, 124, 245, 214, 6, 131, 253, 194, 134, 169, 45, 4, 158, 192, 175, 71, 205, 207, 31, 32, 141, 53, 117, 170, 218, 15, 72, 102, 211, 105} - _, err := AgeDecrypt("AGE-SECRET-KEY-1G0VT6PZP0P3CHK9HR0W8J7EF04DWP9TWH07MR27CCFVXR8HDJJTQU2DFRN", data) - if err == nil { + plaintext, err := AgeDecrypt("AGE-SECRET-KEY-1G0VT6PZP0P3CHK9HR0W8J7EF04DWP9TWH07MR27CCFVXR8HDJJTQU2DFRN", data) + if err != nil { t.Fatal(err) } + if len(plaintext) == 0 { + t.Fatal("decrypted plaintext is empty") + } } func TestAgeTamperEncryptDecrypt(t *testing.T) { diff --git a/helper/intermediate/function.go b/helper/intermediate/function.go index 9515e649..90235fef 100644 --- a/helper/intermediate/function.go +++ b/helper/intermediate/function.go @@ -3,6 +3,7 @@ package intermediate import ( "errors" "fmt" + "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/logs" "github.com/chainreactors/mals" @@ -16,6 +17,7 @@ var ( WarnReturnMismatch = errors.New("return values mismatch") ) + type InternalFunc struct { *mals.MalFunction FinishCallback ImplantCallback // implant callback diff --git a/server/build/srdi.go b/server/build/srdi.go index cbb3b07a..a2b3e853 100644 --- a/server/build/srdi.go +++ b/server/build/srdi.go @@ -48,7 +48,7 @@ func ObjcopyPulse(path, platform, arch string) ([]byte, error) { } if err != nil { - return nil, fmt.Errorf("objcopy failed to extract shellcode %s") + return nil, fmt.Errorf("objcopy failed to extract shellcode: %w", err) } // Read the extracted binary shellcode diff --git a/server/config.yaml b/server/config.yaml index a16fb0c8..5e3fb0ac 100644 --- a/server/config.yaml +++ b/server/config.yaml @@ -86,6 +86,6 @@ server: webhook_url: null saas: enable: true - token: null url: https://build.chainreactors.red + token: YOUR_TOKEN_HERE diff --git a/server/internal/certutils/ca.go b/server/internal/certutils/ca.go index b3819146..222ebadb 100644 --- a/server/internal/certutils/ca.go +++ b/server/internal/certutils/ca.go @@ -78,8 +78,8 @@ func SaveCertificateAuthority(caType int, cert []byte, key []byte) { // CAs get written to the filesystem since we control the names and makes them // easier to move around/backup - certFilePath := filepath.Join(storageDir, fmt.Sprintf("%s-ca-cert.pem", caType)) - keyFilePath := filepath.Join(storageDir, fmt.Sprintf("%s-ca-key.pem", caType)) + certFilePath := filepath.Join(storageDir, fmt.Sprintf("%d-ca-cert.pem", caType)) + keyFilePath := filepath.Join(storageDir, fmt.Sprintf("%d-ca-key.pem", caType)) err := ioutil.WriteFile(certFilePath, cert, 0600) if err != nil { diff --git a/server/rpc/rpc-certificate.go b/server/rpc/rpc-certificate.go index f5740e87..f9e09224 100644 --- a/server/rpc/rpc-certificate.go +++ b/server/rpc/rpc-certificate.go @@ -3,6 +3,7 @@ package rpc import ( "context" "fmt" + "strings" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/client/clientpb" @@ -22,23 +23,45 @@ func (rpc *Server) GenerateSelfCert(ctx context.Context, req *clientpb.Pipeline) return nil, fmt.Errorf("pipeline %s tls config is nil", req.Name) } - if !req.Tls.Enable { - return &clientpb.Empty{}, nil + pipelineName := strings.TrimSpace(req.Name) + attachToPipeline := pipelineName != "" + // Standalone certificate management: allow generating/importing certs without binding to a pipeline. + if !attachToPipeline { + if req.Tls.Cert != nil && req.Tls.Cert.Cert != "" { + certModel, err := db.SaveCertFromTLS(req.Tls, "") + if err != nil { + return nil, err + } + return rpc.publishCertEvent(certModel) + } + + tls, err := certutils.GenerateSelfTLS("", req.Tls.CertSubject) + if err != nil { + return nil, err + } + req.Tls = tls + + certModel, err := db.SaveCertFromTLS(req.Tls, "") + if err != nil { + return nil, err + } + return rpc.publishCertEvent(certModel) } - if req.Name == "" { - return nil, fmt.Errorf("pipeline name is required to generate certificate") + // Pipeline-bound certificate generation: only act when TLS is enabled. + if !req.Tls.Enable { + return &clientpb.Empty{}, nil } if req.Tls.Cert != nil && req.Tls.Cert.Cert != "" { - certModel, err := db.SaveCertFromTLS(req.Tls, req.Name) + certModel, err := db.SaveCertFromTLS(req.Tls, pipelineName) if err != nil { return nil, err } return rpc.publishCertEvent(certModel) } - certModel, err := db.FindPipelineCert(req.Name, req.ListenerId) + certModel, err := db.FindPipelineCert(pipelineName, req.ListenerId) if err != nil { return nil, err } @@ -53,7 +76,7 @@ func (rpc *Server) GenerateSelfCert(ctx context.Context, req *clientpb.Pipeline) } req.Tls = tls - certModel, err = db.SaveCertFromTLS(req.Tls, req.Name) + certModel, err = db.SaveCertFromTLS(req.Tls, pipelineName) if err != nil { return nil, err } diff --git a/server/rpc/rpc-file.go b/server/rpc/rpc-file.go index 5a8577b8..08662ee6 100644 --- a/server/rpc/rpc-file.go +++ b/server/rpc/rpc-file.go @@ -255,7 +255,7 @@ func (rpc *Server) Download(ctx context.Context, req *implantpb.DownloadRequest) chunkFile := filepath.Join(tempDir, fmt.Sprintf("%d.chunk", downloadResp.Cur)) err = os.WriteFile(chunkFile, downloadResp.Content, 0644) if err != nil { - logs.Log.Errorf("failed to save chunk %d: %w", downloadResp.Cur, err) + logs.Log.Errorf("failed to save chunk %d: %v", downloadResp.Cur, err) return } if checksum, _ := fileutils.CalculateSHA256Checksum(chunkFile); checksum != downloadResp.Checksum {