11package main
22
33import (
4+ "errors"
45 "flag"
56 "fmt"
67 "io"
8+ "log"
79 "os"
810 "path/filepath"
911 "strings"
@@ -12,6 +14,7 @@ import (
1214 "github.com/atotto/clipboard"
1315 "github.com/briandowns/spinner"
1416 "github.com/chzyer/readline"
17+ "github.com/sashabaranov/go-openai"
1518
1619 "github.com/shellfly/aoi/pkg/chatgpt"
1720 "github.com/shellfly/aoi/pkg/color"
@@ -24,21 +27,45 @@ the character laughing man who named Aoi, so you named yourself Aoi. Respond
2427like we are good friend.
2528`
2629
27- func main () {
28- startUp ()
29-
30+ func InitClient () (* openai.Client , string , error ) {
3031 var model , openaiAPIKey , openaiAPIBaseUrl string
32+ var azureDeployment string
3133 flag .StringVar (& openaiAPIBaseUrl , "openai_api_base_url" , os .Getenv ("OPENAI_API_BASE_URL" ), "OpenAI API Base Url, default: https://api.openai.com" )
3234 flag .StringVar (& openaiAPIKey , "openai_api_key" , os .Getenv ("OPENAI_API_KEY" ), "OpenAI API key" )
3335 flag .StringVar (& model , "model" , "gpt-3.5-turbo" , "model to use" )
36+ flag .StringVar (& azureDeployment , "azure.deployment" , "" , "azure deployment name of the model" )
3437 flag .Parse ()
3538
36- // Create an AI
37- ai , err := chatgpt .NewAI (openaiAPIBaseUrl , openaiAPIKey , model )
39+ if openaiAPIKey == "" {
40+ return nil , "" , errors .New ("Please set the OPENAI_API_KEY environment variable" )
41+ }
42+
43+ var config openai.ClientConfig
44+ if azureDeployment != "" {
45+ if openaiAPIBaseUrl == "" {
46+ return nil , "" , errors .New ("Please set the OPENAI_API_BASE_URL to your azure endpoint" )
47+ }
48+ config = openai .DefaultAzureConfig (openaiAPIKey , openaiAPIBaseUrl )
49+ config .AzureModelMapperFunc = func (model string ) string {
50+ return azureDeployment
51+ }
52+ } else {
53+ config = openai .DefaultConfig (openaiAPIKey )
54+ if openaiAPIBaseUrl != "" {
55+ config .BaseURL = openaiAPIBaseUrl
56+ }
57+ }
58+ client := openai .NewClientWithConfig (config )
59+ return client , model , nil
60+ }
61+
62+ func main () {
63+ startUp ()
64+ client , model , err := InitClient ()
3865 if err != nil {
39- fmt .Println ("create ai error: " , err )
40- return
66+ log .Fatal (err )
4167 }
68+ ai := chatgpt .NewAI (client , model )
4269 ai .SetSystem (system )
4370
4471 configDir := makeDir (".aoi" )
@@ -90,6 +117,7 @@ func main() {
90117 // If previous is finished try to create a new one, otherwise continue
91118 // to reuse it for prompts
92119 if cmd .IsFinished () {
120+ ai .Reset ()
93121 cmd , prompts = command .Parse (input )
94122 rl .SetPrompt (color .Yellow (cmd .Prompt (userPrompt )))
95123 } else {
0 commit comments