Skip to content

Commit 885026b

Browse files
authored
Merge pull request #3 from shellfly/feat-azure
feat: add azure openai support
2 parents ade11ac + 4bfc79f commit 885026b

File tree

5 files changed

+48
-24
lines changed

5 files changed

+48
-24
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ require (
88
github.com/chzyer/readline v1.5.1
99
github.com/kevinburke/ssh_config v1.2.0
1010
github.com/rest-go/rest v0.1.3
11-
github.com/sashabaranov/go-openai v1.5.2
11+
github.com/sashabaranov/go-openai v1.10.1
1212
github.com/stretchr/testify v1.8.2
1313
golang.org/x/crypto v0.7.0
1414
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OK
107107
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
108108
github.com/sashabaranov/go-openai v1.5.2 h1:Gtn5HZEL25//rDDLEX+Anw5FI8TUC6gqIeM9BDBOO18=
109109
github.com/sashabaranov/go-openai v1.5.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
110+
github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ=
111+
github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
110112
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
111113
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
112114
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=

main.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package main
22

33
import (
4+
"errors"
45
"flag"
56
"fmt"
67
"io"
8+
"log"
79
"os"
810
"path/filepath"
911
"strings"
@@ -12,6 +14,7 @@ import (
1214
"github.com/atotto/clipboard"
1315
"github.com/briandowns/spinner"
1416
"github.com/chzyer/readline"
17+
"github.com/sashabaranov/go-openai"
1518

1619
"github.com/shellfly/aoi/pkg/chatgpt"
1720
"github.com/shellfly/aoi/pkg/color"
@@ -24,21 +27,45 @@ the character laughing man who named Aoi, so you named yourself Aoi. Respond
2427
like we are good friend.
2528
`
2629

27-
func main() {
28-
startUp()
29-
30+
func InitClient() (*openai.Client, string, error) {
3031
var model, openaiAPIKey, openaiAPIBaseUrl string
32+
var azureDeployment string
3133
flag.StringVar(&openaiAPIBaseUrl, "openai_api_base_url", os.Getenv("OPENAI_API_BASE_URL"), "OpenAI API Base Url, default: https://api.openai.com")
3234
flag.StringVar(&openaiAPIKey, "openai_api_key", os.Getenv("OPENAI_API_KEY"), "OpenAI API key")
3335
flag.StringVar(&model, "model", "gpt-3.5-turbo", "model to use")
36+
flag.StringVar(&azureDeployment, "azure.deployment", "", "azure deployment name of the model")
3437
flag.Parse()
3538

36-
// Create an AI
37-
ai, err := chatgpt.NewAI(openaiAPIBaseUrl, openaiAPIKey, model)
39+
if openaiAPIKey == "" {
40+
return nil, "", errors.New("Please set the OPENAI_API_KEY environment variable")
41+
}
42+
43+
var config openai.ClientConfig
44+
if azureDeployment != "" {
45+
if openaiAPIBaseUrl == "" {
46+
return nil, "", errors.New("Please set the OPENAI_API_BASE_URL to your azure endpoint")
47+
}
48+
config = openai.DefaultAzureConfig(openaiAPIKey, openaiAPIBaseUrl)
49+
config.AzureModelMapperFunc = func(model string) string {
50+
return azureDeployment
51+
}
52+
} else {
53+
config = openai.DefaultConfig(openaiAPIKey)
54+
if openaiAPIBaseUrl != "" {
55+
config.BaseURL = openaiAPIBaseUrl
56+
}
57+
}
58+
client := openai.NewClientWithConfig(config)
59+
return client, model, nil
60+
}
61+
62+
func main() {
63+
startUp()
64+
client, model, err := InitClient()
3865
if err != nil {
39-
fmt.Println("create ai error: ", err)
40-
return
66+
log.Fatal(err)
4167
}
68+
ai := chatgpt.NewAI(client, model)
4269
ai.SetSystem(system)
4370

4471
configDir := makeDir(".aoi")
@@ -90,6 +117,7 @@ func main() {
90117
// If previous is finished try to create a new one, otherwise continue
91118
// to reuse it for prompts
92119
if cmd.IsFinished() {
120+
ai.Reset()
93121
cmd, prompts = command.Parse(input)
94122
rl.SetPrompt(color.Yellow(cmd.Prompt(userPrompt)))
95123
} else {

pkg/chatgpt/ai.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package chatgpt
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"strings"
87
"time"
@@ -22,26 +21,17 @@ type AI struct {
2221
debug bool
2322
}
2423

25-
func NewAI(apiBaseUrl, apiKey, model string) (*AI, error) {
26-
if apiKey == "" {
27-
return nil, errors.New("Please set the OPENAI_API_KEY environment variable")
28-
}
29-
30-
// Create a new OpenAI API client with the provided API key
31-
config := openai.DefaultConfig(apiKey)
32-
if apiBaseUrl != "" {
33-
config.BaseURL = apiBaseUrl + "/v1"
34-
}
35-
client := openai.NewClientWithConfig(config)
24+
func NewAI(client *openai.Client, model string) *AI {
3625
messages := make([]openai.ChatCompletionMessage, 0, 2*MessageLimit)
3726
ai := &AI{
3827
client: client,
3928
model: model,
4029
messages: messages,
4130
debug: false,
4231
}
43-
return ai, nil
32+
return ai
4433
}
34+
4535
func (ai *AI) SetSystem(system string) {
4636
ai.system = system
4737
ai.messages = []openai.ChatCompletionMessage{NewMessage(openai.ChatMessageRoleSystem, system)}
@@ -70,7 +60,11 @@ func (ai *AI) Query(prompts []string) (string, error) {
7060
ai.limitTokens()
7161

7262
if ai.debug {
73-
fmt.Println(ai.messages)
63+
fmt.Println("---debug---")
64+
for _, msg := range ai.messages {
65+
fmt.Println(msg)
66+
}
67+
fmt.Println("---debug---")
7468
}
7569
// Set the request parameters for the completion API
7670
req := openai.ChatCompletionRequest{

pkg/chatgpt/ai_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import (
99
)
1010

1111
func TestAI(t *testing.T) {
12-
ai, err := NewAI("https:...", "api key", "model")
13-
assert.Nil(t, err)
12+
client := openai.NewClient("api key")
13+
ai := NewAI(client, "model")
1414
t.Run("limit tokens", func(t *testing.T) {
1515
ai.messages = make([]openai.ChatCompletionMessage, MessageLimit+2)
1616
ai.messages[0] = NewMessage("system", "message")

0 commit comments

Comments
 (0)