Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models, Files and Fine-Tuning APIs #20

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Dependency directories (remove the comment below to include it)
# vendor/

# Go workspace file
go.work

# Env
.env
.env.*

# IDEs
.idea/
.vscode/

# macOS
.DS_Store
143 changes: 143 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package gpt3

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)

const (
DEFAULT_BASE_URL = "https://api.openai.com/v1"
DEFAULT_USER_AGENT = "gpt3-go"
DEFAULT_TIMEOUT = 30
)

var dataPrefix = []byte("data: ")
var streamTerminationPrefix = []byte("[DONE]")

type Client interface {
Models(ctx context.Context) (*ModelsResponse, error)
Model(ctx context.Context, model string) (*ModelObject, error)
Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error)
CompletionStream(ctx context.Context, request CompletionRequest, onData func(*CompletionResponse)) error
Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error)
Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error)
Files(ctx context.Context) (*FilesResponse, error)
UploadFile(ctx context.Context, request UploadFileRequest) (*FileObject, error)
DeleteFile(ctx context.Context, fileID string) (*DeleteFileResponse, error)
File(ctx context.Context, fileID string) (*FileObject, error)
FileContent(ctx context.Context, fileID string) ([]byte, error)
CreateFineTune(ctx context.Context, request CreateFineTuneRequest) (*FineTuneObject, error)
ListFineTunes(ctx context.Context) (*ListFineTunesResponse, error)
FineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error)
CancelFineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error)
FineTuneEvents(ctx context.Context, request FineTuneEventsRequest) (*FineTuneEventsResponse, error)
FineTuneStreamEvents(ctx context.Context, request FineTuneEventsRequest, onData func(*FineTuneEvent)) error
DeleteFineTuneModel(ctx context.Context, modelID string) (*DeleteFineTuneModelResponse, error)

// Deprecated
CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error)
CompletionStreamWithEngine(ctx context.Context, engine string, request CompletionRequest, onData func(*CompletionResponse)) error
}

type client struct {
baseURL string
apiKey string
orgID string
userAgent string
httpClient *http.Client
defaultModel string
}

func NewClient(apiKey string, options ...ClientOption) (Client, error) {
c := &client{
baseURL: DEFAULT_BASE_URL,
apiKey: apiKey,
orgID: "",
userAgent: DEFAULT_USER_AGENT,
httpClient: &http.Client{Timeout: time.Duration(DEFAULT_TIMEOUT) * time.Second},
defaultModel: DavinciModel,
}

for _, option := range options {
if err := option(c); err != nil {
return nil, err
}
}

return c, nil
}

func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) {
bodyReader, err := jsonBodyReader(payload)
if err != nil {
return nil, err
}
url := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
return nil, err
}
req.Header.Set("Content-type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
req.Header.Set("User-Agent", c.userAgent)
if len(c.orgID) > 0 {
req.Header.Set("OpenAI-Organization", c.orgID)
}
return req, nil
}

func (c *client) performRequest(req *http.Request) (*http.Response, error) {
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was testing out the changes on this branch locally and was running into errors, that I believe are from this line. I don't think we used to do this defer/body close from the performRequest happy path, so it was giving me an error when trying to run go run cmd/test/main.go

invalid json response: http2: response body closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah let me fix that, I believe I was told to move it there, does this occur when there is no response body?

return resp, checkForSuccess(resp)
}

func checkForSuccess(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read from body: %w", err)
}
var result APIErrorResponse
if err := json.Unmarshal(data, &result); err != nil {
// if we can't decode the json error then create an unexpected error
apiError := APIError{
StatusCode: resp.StatusCode,
Type: "Unexpected",
Message: string(data),
}
return apiError
}
result.Error.StatusCode = resp.StatusCode
return result.Error
}

func getResponseObject(rsp *http.Response, v interface{}) error {
defer rsp.Body.Close()
if err := json.NewDecoder(rsp.Body).Decode(v); err != nil {
return fmt.Errorf("invalid json response: %w", err)
}
return nil
}

func jsonBodyReader(body interface{}) (io.Reader, error) {
if body == nil {
// the body is allowed to be nil so we return an empty buffer
return bytes.NewBuffer(nil), nil
}
raw, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed encoding json: %w", err)
}
return bytes.NewBuffer(raw), nil
}
10 changes: 5 additions & 5 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import (
// ClientOption are options that can be passed when creating a new client
type ClientOption func(*client) error

// WithOrg is a client option that allows you to override the organization ID
// WithOrg is a client option that allows you to set the organization ID
func WithOrg(id string) ClientOption {
return func(c *client) error {
c.idOrg = id
c.orgID = id
return nil
}
}

// WithDefaultEngine is a client option that allows you to override the default engine of the client
func WithDefaultEngine(engine string) ClientOption {
// WithDefaultModel is a client option that allows you to override the default model of the client
func WithDefaultModel(model string) ClientOption {
return func(c *client) error {
c.defaultEngine = engine
c.defaultModel = model
return nil
}
}
Expand Down
36 changes: 36 additions & 0 deletions client_options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package gpt3_test

import (
"net/http"
"testing"

"github.com/PullRequestInc/go-gpt3"
"github.com/stretchr/testify/assert"
)

func TestClient(t *testing.T) {
testCases := []struct {
name string
options []gpt3.ClientOption
}{
{
name: "test-key",
options: []gpt3.ClientOption{
gpt3.WithOrg("test-org"),
gpt3.WithDefaultModel("test-model"),
gpt3.WithUserAgent("test-agent"),
gpt3.WithBaseURL("test-url"),
gpt3.WithHTTPClient(&http.Client{}),
gpt3.WithTimeout(10),
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client, err := gpt3.NewClient(tc.name, tc.options...)
assert.Nil(t, err)
assert.NotNil(t, client)
})
}
}
14 changes: 14 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package gpt3_test

import (
"testing"

"github.com/PullRequestInc/go-gpt3"
"github.com/stretchr/testify/assert"
)

func TestInitNewClient(t *testing.T) {
client, err := gpt3.NewClient("test-key")
assert.Nil(t, err)
assert.NotNil(t, client)
}
10 changes: 8 additions & 2 deletions cmd/errors/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ import (
)

func main() {
godotenv.Load()
err := godotenv.Load()
if err != nil {
log.Fatalln(err)
}

apiKey := os.Getenv("API_KEY")
if apiKey == "" {
log.Fatalln("Missing API KEY")
}

ctx := context.Background()
client := gpt3.NewClient(apiKey)
client, err := gpt3.NewClient(apiKey)
if err != nil {
log.Fatalln(err)
}

resp, err := client.Completion(ctx, gpt3.CompletionRequest{
Prompt: []string{
Expand Down
5 changes: 4 additions & 1 deletion cmd/test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ func main() {
}

ctx := context.Background()
client := gpt3.NewClient(apiKey)
client, err := gpt3.NewClient(apiKey)
if err != nil {
log.Fatalln(err)
}

resp, err := client.Completion(ctx, gpt3.CompletionRequest{
Prompt: []string{
Expand Down
Loading