Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models, Files and Fine-Tuning APIs #20

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Dependency directories (remove the comment below to include it)
# vendor/

# Go workspace file
go.work

# Env
.env
.env.*

# IDEs
.idea/
.vscode/

# macOS
.DS_Store
139 changes: 139 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package gpt3

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"time"
)

const (
DEFAULT_BASE_URL = "https://api.openai.com/v1"
DEFAULT_USER_AGENT = "gpt3-go"
DEFAULT_TIMEOUT = 30
)

var dataPrefix = []byte("data: ")
var streamTerminationPrefix = []byte("[DONE]")

type Client interface {
Models(ctx context.Context) (*ModelsResponse, error)
Model(ctx context.Context, model string) (*ModelObject, error)
Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error)
CompletionStream(ctx context.Context, request CompletionRequest, onData func(*CompletionResponse)) error
Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error)
Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error)
Files(ctx context.Context) (*FilesResponse, error)
UploadFile(ctx context.Context, request UploadFileRequest) (*FileObject, error)
DeleteFile(ctx context.Context, fileID string) (*DeleteFileResponse, error)
File(ctx context.Context, fileID string) (*FileObject, error)
FileContent(ctx context.Context, fileID string) ([]byte, error)
CreateFineTune(ctx context.Context, request CreateFineTuneRequest) (*FineTuneObject, error)
FineTunes(ctx context.Context) (*FineTunesResponse, error)
FineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error)
CancelFineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error)
FineTuneEvents(ctx context.Context, request FineTuneEventsRequest) (*FineTuneEventsResponse, error)
FineTuneStreamEvents(ctx context.Context, request FineTuneEventsRequest, onData func(*FineTuneEvent)) error
DeleteFineTuneModel(ctx context.Context, modelID string) (*DeleteFineTuneModelResponse, error)
}

type client struct {
baseURL string
apiKey string
orgID string
userAgent string
httpClient *http.Client
defaultModel string
}

func NewClient(apiKey string, options ...ClientOption) (Client, error) {
c := &client{
baseURL: DEFAULT_BASE_URL,
apiKey: apiKey,
orgID: "",
userAgent: DEFAULT_USER_AGENT,
httpClient: &http.Client{Timeout: time.Duration(DEFAULT_TIMEOUT) * time.Second},
defaultModel: DavinciModel,
}

for _, option := range options {
if err := option(c); err != nil {
return nil, err
}
}

return c, nil
}

func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) {
bodyReader, err := jsonBodyReader(payload)
if err != nil {
return nil, err
}
url := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
return nil, err
}
req.Header.Set("Content-type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
req.Header.Set("User-Agent", c.userAgent)
if len(c.orgID) > 0 {
req.Header.Set("OpenAI-Organization", c.orgID)
}
return req, nil
}

func (c *client) performRequest(req *http.Request) (*http.Response, error) {
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
return resp, checkForSuccess(resp)
}

func checkForSuccess(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read from body: %w", err)
}
var result APIErrorResponse
if err := json.Unmarshal(data, &result); err != nil {
// if we can't decode the json error then create an unexpected error
apiError := APIError{
StatusCode: resp.StatusCode,
Type: "Unexpected",
Message: string(data),
}
return apiError
}
result.Error.StatusCode = resp.StatusCode
return result.Error
}

func getResponseObject(rsp *http.Response, v interface{}) error {
defer rsp.Body.Close()
if err := json.NewDecoder(rsp.Body).Decode(v); err != nil {
return fmt.Errorf("invalid json response: %w", err)
}
return nil
}

func jsonBodyReader(body interface{}) (io.Reader, error) {
if body == nil {
return bytes.NewBuffer(nil), nil
}
raw, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed encoding json: %w", err)
}
return bytes.NewBuffer(raw), nil
}
10 changes: 5 additions & 5 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import (
// ClientOption are options that can be passed when creating a new client
type ClientOption func(*client) error

// WithOrg is a client option that allows you to override the organization ID
// WithOrg is a client option that allows you to set the organization ID
func WithOrg(id string) ClientOption {
return func(c *client) error {
c.idOrg = id
c.orgID = id
return nil
}
}

// WithDefaultEngine is a client option that allows you to override the default engine of the client
func WithDefaultEngine(engine string) ClientOption {
// WithDefaultModel is a client option that allows you to override the default model of the client
func WithDefaultModel(model string) ClientOption {
return func(c *client) error {
c.defaultEngine = engine
c.defaultModel = model
return nil
}
}
Expand Down
46 changes: 46 additions & 0 deletions client_options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package gpt3_test

import (
"net/http"
"testing"

"github.com/PullRequestInc/go-gpt3"
"github.com/stretchr/testify/assert"
)

func TestClientWithOrg(t *testing.T) {
client, err := gpt3.NewClient("test-key", gpt3.WithOrg("test-org"))
assert.Nil(t, err)
assert.NotNil(t, client)
}

func TestClientWithDefaultModel(t *testing.T) {
client, err := gpt3.NewClient("test-key", gpt3.WithDefaultModel("test-model"))
assert.Nil(t, err)
assert.NotNil(t, client)
}

func TestClientWithUserAgent(t *testing.T) {
client, err := gpt3.NewClient("test-key", gpt3.WithUserAgent("test-agent"))
assert.Nil(t, err)
assert.NotNil(t, client)
}

func TestClientWithBaseURL(t *testing.T) {
client, err := gpt3.NewClient("test-key", gpt3.WithBaseURL("test-url"))
assert.Nil(t, err)
assert.NotNil(t, client)
}

func TestClientWithHTTPClient(t *testing.T) {
httpClient := &http.Client{}
client, err := gpt3.NewClient("test-key", gpt3.WithHTTPClient(httpClient))
assert.Nil(t, err)
assert.NotNil(t, client)
}

func TestClientWithTimeout(t *testing.T) {
client, err := gpt3.NewClient("test-key", gpt3.WithTimeout(10))
assert.Nil(t, err)
assert.NotNil(t, client)
}
14 changes: 14 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package gpt3_test

import (
"testing"

"github.com/PullRequestInc/go-gpt3"
"github.com/stretchr/testify/assert"
)

func TestInitNewClient(t *testing.T) {
client, err := gpt3.NewClient("test-key")
assert.Nil(t, err)
assert.NotNil(t, client)
}
10 changes: 8 additions & 2 deletions cmd/errors/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ import (
)

func main() {
godotenv.Load()
err := godotenv.Load()
if err != nil {
log.Fatalln(err)
}

apiKey := os.Getenv("API_KEY")
if apiKey == "" {
log.Fatalln("Missing API KEY")
}

ctx := context.Background()
client := gpt3.NewClient(apiKey)
client, err := gpt3.NewClient(apiKey)
if err != nil {
log.Fatalln(err)
}

resp, err := client.Completion(ctx, gpt3.CompletionRequest{
Prompt: []string{
Expand Down
5 changes: 4 additions & 1 deletion cmd/test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ func main() {
}

ctx := context.Background()
client := gpt3.NewClient(apiKey)
client, err := gpt3.NewClient(apiKey)
if err != nil {
log.Fatalln(err)
}

resp, err := client.Completion(ctx, gpt3.CompletionRequest{
Prompt: []string{
Expand Down
Loading