Skip to content

Commit 341efdb

Browse files
authored
refactor(go)!: Updated all primitives to unified options (#2550)
1 parent 395de91 commit 341efdb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+906
-867
lines changed

go/ai/embedder.go

+20-74
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package ai
1818

1919
import (
2020
"context"
21-
"errors"
21+
"fmt"
2222

2323
"github.com/firebase/genkit/go/core"
2424
"github.com/firebase/genkit/go/internal/atype"
@@ -33,29 +33,8 @@ type Embedder interface {
3333
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
3434
}
3535

36-
// An embedderActionDef is used to convert a document to a
37-
// multidimensional vector.
38-
type embedderActionDef core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
39-
40-
type embedderAction = core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
41-
42-
// EmbedRequest is the data we pass to convert one or more documents
43-
// to a multidimensional vector.
44-
type EmbedRequest struct {
45-
Documents []*Document `json:"input"`
46-
Options any `json:"options,omitempty"`
47-
}
48-
49-
type EmbedResponse struct {
50-
// One embedding for each Document in the request, in the same order.
51-
Embeddings []*DocumentEmbedding `json:"embeddings"`
52-
}
53-
54-
// DocumentEmbedding holds emdedding information about a single document.
55-
type DocumentEmbedding struct {
56-
// The vector for the embedding.
57-
Embedding []float32 `json:"embedding"`
58-
}
36+
// An embedder is used to convert a document to a multidimensional vector.
37+
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
5938

6039
// DefineEmbedder registers the given embed function as an action, and returns an
6140
// [Embedder] that runs it.
@@ -64,12 +43,7 @@ func DefineEmbedder(
6443
provider, name string,
6544
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
6645
) Embedder {
67-
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
68-
}
69-
70-
// IsDefinedEmbedder reports whether an embedder is defined.
71-
func IsDefinedEmbedder(r *registry.Registry, provider, name string) bool {
72-
return LookupEmbedder(r, provider, name) != nil
46+
return (*embedder)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
7347
}
7448

7549
// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
@@ -79,61 +53,33 @@ func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
7953
if action == nil {
8054
return nil
8155
}
82-
return (*embedderActionDef)(action)
83-
}
8456

85-
// Embed runs the given [Embedder].
86-
func (e *embedderActionDef) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
87-
if e == nil {
88-
return nil, errors.New("Embed called on a nil Embedder; check that all embedders are defined")
89-
}
90-
a := (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e)
91-
return a.Run(ctx, req, nil)
57+
return (*embedder)(action)
9258
}
9359

94-
func (e *embedderActionDef) Name() string {
95-
return (*embedderAction)(e).Name()
60+
// Name returns the name of the embedder.
61+
func (e *embedder) Name() string {
62+
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Name()
9663
}
9764

98-
// EmbedOption configures params of the Embed call.
99-
type EmbedOption func(req *EmbedRequest) error
100-
101-
// WithEmbedOptions set embedder options on [EmbedRequest]
102-
func WithEmbedOptions(opts any) EmbedOption {
103-
return func(req *EmbedRequest) error {
104-
req.Options = opts
105-
return nil
106-
}
65+
// Embed runs the given [Embedder].
66+
func (e *embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
67+
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Run(ctx, req, nil)
10768
}
10869

109-
// WithEmbedText adds simple text documents to [EmbedRequest]
110-
func WithEmbedText(text ...string) EmbedOption {
111-
return func(req *EmbedRequest) error {
112-
var docs []*Document
113-
for _, p := range text {
114-
docs = append(docs, DocumentFromText(p, nil))
70+
// Embed invokes the embedder with provided options.
71+
func Embed(ctx context.Context, e Embedder, opts ...EmbedderOption) (*EmbedResponse, error) {
72+
embedOpts := &embedderOptions{}
73+
for _, opt := range opts {
74+
if err := opt.applyEmbedder(embedOpts); err != nil {
75+
return nil, fmt.Errorf("ai.Embed: error applying options: %w", err)
11576
}
116-
req.Documents = append(req.Documents, docs...)
117-
return nil
11877
}
119-
}
12078

121-
// WithEmbedDocs adds documents to [EmbedRequest]
122-
func WithEmbedDocs(docs ...*Document) EmbedOption {
123-
return func(req *EmbedRequest) error {
124-
req.Documents = append(req.Documents, docs...)
125-
return nil
79+
req := &EmbedRequest{
80+
Input: embedOpts.Documents,
81+
Options: embedOpts.Config,
12682
}
127-
}
12883

129-
// Embed invokes the embedder with provided options.
130-
func Embed(ctx context.Context, e Embedder, opts ...EmbedOption) (*EmbedResponse, error) {
131-
req := &EmbedRequest{}
132-
for _, with := range opts {
133-
err := with(req)
134-
if err != nil {
135-
return nil, err
136-
}
137-
}
13884
return e.Embed(ctx, req)
13985
}

go/ai/evaluator.go

+27-62
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ type Evaluator interface {
3838
Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error)
3939
}
4040

41-
type (
42-
evaluatorActionDef core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
43-
44-
evaluatorAction = core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
45-
)
41+
type evaluator core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
4642

4743
// Example is a single example that requires evaluation
4844
type Example struct {
@@ -54,15 +50,12 @@ type Example struct {
5450
TraceIds []string `json:"traceIds,omitempty"`
5551
}
5652

57-
// Dataset is a collection of [Example]
58-
type Dataset = []Example
59-
6053
// EvaluatorRequest is the data we pass to evaluate a dataset.
6154
// The Options field is specific to the actual evaluator implementation.
6255
type EvaluatorRequest struct {
63-
Dataset *Dataset `json:"dataset"`
64-
EvaluationId string `json:"evalRunId"`
65-
Options any `json:"options,omitempty"`
56+
Dataset []*Example `json:"dataset"`
57+
EvaluationId string `json:"evalRunId"`
58+
Options any `json:"options,omitempty"`
6659
}
6760

6861
// ScoreStatus is an enum used to indicate if a Score has passed or failed. This
@@ -141,20 +134,18 @@ func DefineEvaluator(r *registry.Registry, provider, name string, options *Evalu
141134
metadataMap["evaluatorDisplayName"] = options.DisplayName
142135
metadataMap["evaluatorDefinition"] = options.Definition
143136

144-
actionDef := (*evaluatorActionDef)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
137+
actionDef := (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
145138
var evalResponses []EvaluationResult
146-
dataset := *req.Dataset
147-
for i := 0; i < len(dataset); i++ {
148-
datapoint := dataset[i]
139+
for _, datapoint := range req.Dataset {
149140
if datapoint.TestCaseId == "" {
150141
datapoint.TestCaseId = uuid.New().String()
151142
}
152143
_, err := tracing.RunInNewSpan(ctx, r.TracingState(), fmt.Sprintf("TestCase %s", datapoint.TestCaseId), "evaluator", false, datapoint,
153-
func(ctx context.Context, input Example) (*EvaluatorCallbackResponse, error) {
144+
func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) {
154145
traceId := trace.SpanContextFromContext(ctx).TraceID().String()
155146
spanId := trace.SpanContextFromContext(ctx).SpanID().String()
156147
callbackRequest := EvaluatorCallbackRequest{
157-
Input: input,
148+
Input: *input,
158149
Options: req.Options,
159150
}
160151
evaluatorResponse, err := eval(ctx, &callbackRequest)
@@ -202,66 +193,40 @@ func DefineBatchEvaluator(r *registry.Registry, provider, name string, options *
202193
metadataMap["evaluatorDisplayName"] = options.DisplayName
203194
metadataMap["evaluatorDefinition"] = options.Definition
204195

205-
return (*evaluatorActionDef)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
206-
}
207-
208-
// IsDefinedEvaluator reports whether an [Evaluator] is defined.
209-
func IsDefinedEvaluator(r *registry.Registry, provider, name string) bool {
210-
return (*evaluatorActionDef)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name)) != nil
196+
return (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
211197
}
212198

213199
// LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator].
214200
// It returns nil if the evaluator was not defined.
215201
func LookupEvaluator(r *registry.Registry, provider, name string) Evaluator {
216-
return (*evaluatorActionDef)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
217-
}
218-
219-
// EvaluateOption configures params of the Embed call.
220-
type EvaluateOption func(req *EvaluatorRequest) error
221-
222-
// WithEvaluateDataset set the dataset on [EvaluatorRequest]
223-
func WithEvaluateDataset(dataset *Dataset) EvaluateOption {
224-
return func(req *EvaluatorRequest) error {
225-
req.Dataset = dataset
226-
return nil
227-
}
228-
}
229-
230-
// WithEvaluateId set evaluation ID on [EvaluatorRequest]
231-
func WithEvaluateId(evaluationId string) EvaluateOption {
232-
return func(req *EvaluatorRequest) error {
233-
req.EvaluationId = evaluationId
234-
return nil
235-
}
236-
}
237-
238-
// WithEvaluateOptions set evaluator options on [EvaluatorRequest]
239-
func WithEvaluateOptions(opts any) EvaluateOption {
240-
return func(req *EvaluatorRequest) error {
241-
req.Options = opts
242-
return nil
243-
}
202+
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
244203
}
245204

246205
// Evaluate calls the retrivers with provided options.
247-
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluateOption) (*EvaluatorResponse, error) {
248-
req := &EvaluatorRequest{}
249-
for _, with := range opts {
250-
err := with(req)
206+
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluatorOption) (*EvaluatorResponse, error) {
207+
evalOpts := &evaluatorOptions{}
208+
for _, opt := range opts {
209+
err := opt.applyEvaluator(evalOpts)
251210
if err != nil {
252211
return nil, err
253212
}
254213
}
214+
215+
req := &EvaluatorRequest{
216+
Dataset: evalOpts.Dataset,
217+
EvaluationId: evalOpts.ID,
218+
Options: evalOpts.Config,
219+
}
220+
255221
return r.Evaluate(ctx, req)
256222
}
257223

258-
func (r *evaluatorActionDef) Name() string { return (*evaluatorAction)(r).Name() }
224+
// Name returns the name of the evaluator.
225+
func (e evaluator) Name() string {
226+
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Name()
227+
}
259228

260229
// Evaluate runs the given [Evaluator].
261-
func (e *evaluatorActionDef) Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
262-
if e == nil {
263-
return nil, errors.New("Evaluator called on a nil Evaluator; check that all evaluators are defined")
264-
}
265-
a := (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(e)
266-
return a.Run(ctx, req, nil)
230+
func (e evaluator) Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
231+
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Run(ctx, req, nil)
267232
}

go/ai/evaluator_test.go

+9-38
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,10 @@ var testEvalFunc = func(ctx context.Context, req *EvaluatorCallbackRequest) (*Ev
4242

4343
var testBatchEvalFunc = func(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
4444
var evalResponses []EvaluationResult
45-
dataset := *req.Dataset
46-
for i := 0; i < len(dataset); i++ {
47-
input := dataset[i]
48-
fmt.Printf("%+v\n", input)
45+
for _, datapoint := range req.Dataset {
46+
fmt.Printf("%+v\n", datapoint)
4947
m := make(map[string]any)
50-
m["reasoning"] = fmt.Sprintf("batch of cookies, %s", input.Input)
48+
m["reasoning"] = fmt.Sprintf("batch of cookies, %s", datapoint.Input)
5149
m["options"] = req.Options
5250
score := Score{
5351
Id: "testScore",
@@ -56,7 +54,7 @@ var testBatchEvalFunc = func(ctx context.Context, req *EvaluatorRequest) (*Evalu
5654
Details: m,
5755
}
5856
callbackResponse := EvaluationResult{
59-
TestCaseId: input.TestCaseId,
57+
TestCaseId: datapoint.TestCaseId,
6058
Evaluation: []Score{score},
6159
}
6260
evalResponses = append(evalResponses, callbackResponse)
@@ -74,7 +72,7 @@ var evalOptions = EvaluatorOptions{
7472
IsBilled: false,
7573
}
7674

77-
var dataset = Dataset{
75+
var dataset = []*Example{
7876
{
7977
Input: "hello world",
8078
},
@@ -84,7 +82,7 @@ var dataset = Dataset{
8482
}
8583

8684
var testRequest = EvaluatorRequest{
87-
Dataset: &dataset,
85+
Dataset: dataset,
8886
EvaluationId: "testrun",
8987
Options: "test-options",
9088
}
@@ -162,33 +160,6 @@ func TestFailingEvaluator(t *testing.T) {
162160
}
163161
}
164162

165-
func TestIsDefinedEvaluator(t *testing.T) {
166-
r, err := registry.New()
167-
if err != nil {
168-
t.Fatal(err)
169-
}
170-
171-
_, err = DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testEvalFunc)
172-
if err != nil {
173-
t.Fatal(err)
174-
}
175-
_, err = DefineBatchEvaluator(r, "test", "testBatchEvaluator", &evalOptions, testBatchEvalFunc)
176-
if err != nil {
177-
t.Fatal(err)
178-
}
179-
180-
if got, want := IsDefinedEvaluator(r, "test", "testEvaluator"), true; got != want {
181-
t.Errorf("got %v, want %v", got, want)
182-
}
183-
if got, want := IsDefinedEvaluator(r, "test", "testBatchEvaluator"), true; got != want {
184-
t.Errorf("got %v, want %v", got, want)
185-
}
186-
if got, want := IsDefinedEvaluator(r, "test", "fakefakefake"), false; got != want {
187-
t.Errorf("got %v, want %v", got, want)
188-
}
189-
190-
}
191-
192163
func TestLookupEvaluator(t *testing.T) {
193164
r, err := registry.New()
194165
if err != nil {
@@ -224,9 +195,9 @@ func TestEvaluate(t *testing.T) {
224195
}
225196

226197
resp, err := Evaluate(context.Background(), evalAction,
227-
WithEvaluateDataset(&dataset),
228-
WithEvaluateId("testrun"),
229-
WithEvaluateOptions("test-options"))
198+
WithDataset(dataset...),
199+
WithID("testrun"),
200+
WithConfig("test-options"))
230201
if err != nil {
231202
t.Fatal(err)
232203
}

0 commit comments

Comments
 (0)