Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(go)!: Updated all primitives to unified options. #2550

Merged
merged 13 commits into from
Apr 7, 2025
94 changes: 20 additions & 74 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package ai

import (
"context"
"errors"
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
Expand All @@ -33,29 +33,8 @@ type Embedder interface {
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}

// An embedderActionDef is used to convert a document to a
// multidimensional vector.
type embedderActionDef core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

type embedderAction = core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

// EmbedRequest is the data we pass to convert one or more documents
// to a multidimensional vector.
type EmbedRequest struct {
Documents []*Document `json:"input"`
Options any `json:"options,omitempty"`
}

type EmbedResponse struct {
// One embedding for each Document in the request, in the same order.
Embeddings []*DocumentEmbedding `json:"embeddings"`
}

// DocumentEmbedding holds emdedding information about a single document.
type DocumentEmbedding struct {
// The vector for the embedding.
Embedding []float32 `json:"embedding"`
}
// An embedder is used to convert a document to a multidimensional vector.
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
Expand All @@ -64,12 +43,7 @@ func DefineEmbedder(
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
}

// IsDefinedEmbedder reports whether an embedder is defined.
func IsDefinedEmbedder(r *registry.Registry, provider, name string) bool {
return LookupEmbedder(r, provider, name) != nil
return (*embedder)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
Expand All @@ -79,61 +53,33 @@ func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
if action == nil {
return nil
}
return (*embedderActionDef)(action)
}

// Embed runs the given [Embedder].
func (e *embedderActionDef) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
if e == nil {
return nil, errors.New("Embed called on a nil Embedder; check that all embedders are defined")
}
a := (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e)
return a.Run(ctx, req, nil)
return (*embedder)(action)
}

func (e *embedderActionDef) Name() string {
return (*embedderAction)(e).Name()
// Name returns the name of the embedder.
func (e *embedder) Name() string {
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Name()
}

// EmbedOption configures params of the Embed call.
type EmbedOption func(req *EmbedRequest) error

// WithEmbedOptions set embedder options on [EmbedRequest]
func WithEmbedOptions(opts any) EmbedOption {
return func(req *EmbedRequest) error {
req.Options = opts
return nil
}
// Embed runs the given [Embedder].
func (e *embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
return (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e).Run(ctx, req, nil)
}

// WithEmbedText adds simple text documents to [EmbedRequest]
func WithEmbedText(text ...string) EmbedOption {
return func(req *EmbedRequest) error {
var docs []*Document
for _, p := range text {
docs = append(docs, DocumentFromText(p, nil))
// Embed invokes the embedder with provided options.
func Embed(ctx context.Context, e Embedder, opts ...EmbedderOption) (*EmbedResponse, error) {
embedOpts := &embedderOptions{}
for _, opt := range opts {
if err := opt.applyEmbedder(embedOpts); err != nil {
return nil, fmt.Errorf("ai.Embed: error applying options: %w", err)
}
req.Documents = append(req.Documents, docs...)
return nil
}
}

// WithEmbedDocs adds documents to [EmbedRequest]
func WithEmbedDocs(docs ...*Document) EmbedOption {
return func(req *EmbedRequest) error {
req.Documents = append(req.Documents, docs...)
return nil
req := &EmbedRequest{
Input: embedOpts.Documents,
Options: embedOpts.Config,
}
}

// Embed invokes the embedder with provided options.
func Embed(ctx context.Context, e Embedder, opts ...EmbedOption) (*EmbedResponse, error) {
req := &EmbedRequest{}
for _, with := range opts {
err := with(req)
if err != nil {
return nil, err
}
}
return e.Embed(ctx, req)
}
89 changes: 27 additions & 62 deletions go/ai/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ type Evaluator interface {
Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error)
}

type (
evaluatorActionDef core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]

evaluatorAction = core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]
)
type evaluator core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}]

// Example is a single example that requires evaluation
type Example struct {
Expand All @@ -54,15 +50,12 @@ type Example struct {
TraceIds []string `json:"traceIds,omitempty"`
}

// Dataset is a collection of [Example]
type Dataset = []Example

// EvaluatorRequest is the data we pass to evaluate a dataset.
// The Options field is specific to the actual evaluator implementation.
type EvaluatorRequest struct {
Dataset *Dataset `json:"dataset"`
EvaluationId string `json:"evalRunId"`
Options any `json:"options,omitempty"`
Dataset []*Example `json:"dataset"`
EvaluationId string `json:"evalRunId"`
Options any `json:"options,omitempty"`
}

// ScoreStatus is an enum used to indicate if a Score has passed or failed. This
Expand Down Expand Up @@ -141,20 +134,18 @@ func DefineEvaluator(r *registry.Registry, provider, name string, options *Evalu
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition

actionDef := (*evaluatorActionDef)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
actionDef := (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) {
var evalResponses []EvaluationResult
dataset := *req.Dataset
for i := 0; i < len(dataset); i++ {
datapoint := dataset[i]
for _, datapoint := range req.Dataset {
if datapoint.TestCaseId == "" {
datapoint.TestCaseId = uuid.New().String()
}
_, err := tracing.RunInNewSpan(ctx, r.TracingState(), fmt.Sprintf("TestCase %s", datapoint.TestCaseId), "evaluator", false, datapoint,
func(ctx context.Context, input Example) (*EvaluatorCallbackResponse, error) {
func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) {
traceId := trace.SpanContextFromContext(ctx).TraceID().String()
spanId := trace.SpanContextFromContext(ctx).SpanID().String()
callbackRequest := EvaluatorCallbackRequest{
Input: input,
Input: *input,
Options: req.Options,
}
evaluatorResponse, err := eval(ctx, &callbackRequest)
Expand Down Expand Up @@ -202,66 +193,40 @@ func DefineBatchEvaluator(r *registry.Registry, provider, name string, options *
metadataMap["evaluatorDisplayName"] = options.DisplayName
metadataMap["evaluatorDefinition"] = options.Definition

return (*evaluatorActionDef)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
}

// IsDefinedEvaluator reports whether an [Evaluator] is defined.
func IsDefinedEvaluator(r *registry.Registry, provider, name string) bool {
return (*evaluatorActionDef)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name)) != nil
return (*evaluator)(core.DefineAction(r, provider, name, atype.Evaluator, map[string]any{"evaluator": metadataMap}, batchEval)), nil
}

// LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator].
// It returns nil if the evaluator was not defined.
func LookupEvaluator(r *registry.Registry, provider, name string) Evaluator {
return (*evaluatorActionDef)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
}

// EvaluateOption configures params of the Embed call.
type EvaluateOption func(req *EvaluatorRequest) error

// WithEvaluateDataset set the dataset on [EvaluatorRequest]
func WithEvaluateDataset(dataset *Dataset) EvaluateOption {
return func(req *EvaluatorRequest) error {
req.Dataset = dataset
return nil
}
}

// WithEvaluateId set evaluation ID on [EvaluatorRequest]
func WithEvaluateId(evaluationId string) EvaluateOption {
return func(req *EvaluatorRequest) error {
req.EvaluationId = evaluationId
return nil
}
}

// WithEvaluateOptions set evaluator options on [EvaluatorRequest]
func WithEvaluateOptions(opts any) EvaluateOption {
return func(req *EvaluatorRequest) error {
req.Options = opts
return nil
}
return (*evaluator)(core.LookupActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, atype.Evaluator, provider, name))
}

// Evaluate calls the retrivers with provided options.
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluateOption) (*EvaluatorResponse, error) {
req := &EvaluatorRequest{}
for _, with := range opts {
err := with(req)
func Evaluate(ctx context.Context, r Evaluator, opts ...EvaluatorOption) (*EvaluatorResponse, error) {
evalOpts := &evaluatorOptions{}
for _, opt := range opts {
err := opt.applyEvaluator(evalOpts)
if err != nil {
return nil, err
}
}

req := &EvaluatorRequest{
Dataset: evalOpts.Dataset,
EvaluationId: evalOpts.ID,
Options: evalOpts.Config,
}

return r.Evaluate(ctx, req)
}

func (r *evaluatorActionDef) Name() string { return (*evaluatorAction)(r).Name() }
// Name returns the name of the evaluator.
func (e evaluator) Name() string {
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Name()
}

// Evaluate runs the given [Evaluator].
func (e *evaluatorActionDef) Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
if e == nil {
return nil, errors.New("Evaluator called on a nil Evaluator; check that all evaluators are defined")
}
a := (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(e)
return a.Run(ctx, req, nil)
func (e evaluator) Evaluate(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
return (*core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}])(&e).Run(ctx, req, nil)
}
47 changes: 9 additions & 38 deletions go/ai/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@ var testEvalFunc = func(ctx context.Context, req *EvaluatorCallbackRequest) (*Ev

var testBatchEvalFunc = func(ctx context.Context, req *EvaluatorRequest) (*EvaluatorResponse, error) {
var evalResponses []EvaluationResult
dataset := *req.Dataset
for i := 0; i < len(dataset); i++ {
input := dataset[i]
fmt.Printf("%+v\n", input)
for _, datapoint := range req.Dataset {
fmt.Printf("%+v\n", datapoint)
m := make(map[string]any)
m["reasoning"] = fmt.Sprintf("batch of cookies, %s", input.Input)
m["reasoning"] = fmt.Sprintf("batch of cookies, %s", datapoint.Input)
m["options"] = req.Options
score := Score{
Id: "testScore",
Expand All @@ -56,7 +54,7 @@ var testBatchEvalFunc = func(ctx context.Context, req *EvaluatorRequest) (*Evalu
Details: m,
}
callbackResponse := EvaluationResult{
TestCaseId: input.TestCaseId,
TestCaseId: datapoint.TestCaseId,
Evaluation: []Score{score},
}
evalResponses = append(evalResponses, callbackResponse)
Expand All @@ -74,7 +72,7 @@ var evalOptions = EvaluatorOptions{
IsBilled: false,
}

var dataset = Dataset{
var dataset = []*Example{
{
Input: "hello world",
},
Expand All @@ -84,7 +82,7 @@ var dataset = Dataset{
}

var testRequest = EvaluatorRequest{
Dataset: &dataset,
Dataset: dataset,
EvaluationId: "testrun",
Options: "test-options",
}
Expand Down Expand Up @@ -162,33 +160,6 @@ func TestFailingEvaluator(t *testing.T) {
}
}

func TestIsDefinedEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
t.Fatal(err)
}

_, err = DefineEvaluator(r, "test", "testEvaluator", &evalOptions, testEvalFunc)
if err != nil {
t.Fatal(err)
}
_, err = DefineBatchEvaluator(r, "test", "testBatchEvaluator", &evalOptions, testBatchEvalFunc)
if err != nil {
t.Fatal(err)
}

if got, want := IsDefinedEvaluator(r, "test", "testEvaluator"), true; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := IsDefinedEvaluator(r, "test", "testBatchEvaluator"), true; got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := IsDefinedEvaluator(r, "test", "fakefakefake"), false; got != want {
t.Errorf("got %v, want %v", got, want)
}

}

func TestLookupEvaluator(t *testing.T) {
r, err := registry.New()
if err != nil {
Expand Down Expand Up @@ -224,9 +195,9 @@ func TestEvaluate(t *testing.T) {
}

resp, err := Evaluate(context.Background(), evalAction,
WithEvaluateDataset(&dataset),
WithEvaluateId("testrun"),
WithEvaluateOptions("test-options"))
WithDataset(dataset...),
WithID("testrun"),
WithConfig("test-options"))
if err != nil {
t.Fatal(err)
}
Expand Down
Loading
Loading