Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions backend/modules/evaluation/consts/eval_target.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) 2025 coze-dev Authors
// SPDX-License-Identifier: Apache-2.0

package consts

import (
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/expt"
)

// InputFieldKeyPromptUserQuery is the input-field key under which the user
// query for a prompt evaluation target is passed. It is defined from
// expt.PromptUserQueryFieldKey so both packages share one value.
const InputFieldKeyPromptUserQuery = expt.PromptUserQueryFieldKey
49 changes: 49 additions & 0 deletions backend/modules/evaluation/consts/eval_target_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2025 coze-dev Authors
// SPDX-License-Identifier: Apache-2.0

package consts

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/expt"
)

// TestInputFieldKeyPromptUserQuery pins the literal value of
// InputFieldKeyPromptUserQuery so an unintended change to the key is caught.
func TestInputFieldKeyPromptUserQuery(t *testing.T) {
	type testCase struct {
		name     string
		expected string
	}

	cases := []testCase{
		{
			name:     "verify InputFieldKeyPromptUserQuery constant value",
			expected: "builtin_prompt_user_query",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, InputFieldKeyPromptUserQuery)
		})
	}
}

// TestInputFieldKeyPromptUserQueryConsistency verifies that
// InputFieldKeyPromptUserQuery stays equal to expt.PromptUserQueryFieldKey,
// the generated constant it is defined from.
func TestInputFieldKeyPromptUserQueryConsistency(t *testing.T) {
	tests := []struct {
		name     string
		expected string
	}{
		{
			name:     "verify consistency with expt.PromptUserQueryFieldKey",
			expected: expt.PromptUserQueryFieldKey,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A single assertion is enough: tt.expected already holds
			// expt.PromptUserQueryFieldKey, and equality is symmetric, so
			// repeating the check with the arguments swapped adds nothing
			// and would mislabel expected/actual in a failure message.
			assert.Equal(t, tt.expected, InputFieldKeyPromptUserQuery)
		})
	}
}
8 changes: 5 additions & 3 deletions backend/modules/evaluation/domain/component/rpc/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ type ExecutePromptParam struct {
PromptVersion string
Variables []*entity.VariableVal
History []*entity.Message
UserQuery *entity.Message
RuntimeParam *string
}

// ExecutePromptResult holds the values returned by a prompt execution.
// Every field is a pointer or slice tagged with omitempty, so unset fields
// are left out of the JSON encoding.
// NOTE(review): Content, ToolCalls and TokenUsage are listed twice below. This
// looks like the removed and re-added lines of a diff rendered together —
// confirm against the real file, which should declare each field once.
type ExecutePromptResult struct {
Content *string `json:"content,omitempty"`
ToolCalls []*entity.ToolCall `json:"tool_calls,omitempty"`
TokenUsage *entity.TokenUsage `json:"token_usage,omitempty"`
// Content is the plain-text output, when present.
Content *string `json:"content,omitempty"`
// ToolCalls lists the tool calls carried in the result, when present.
ToolCalls []*entity.ToolCall `json:"tool_calls,omitempty"`
// TokenUsage is the token accounting for the execution, when present.
TokenUsage *entity.TokenUsage `json:"token_usage,omitempty"`
// MultiContent is the structured output, when present.
// NOTE(review): the Execute hunk later in this view uses it in preference to
// Content and ToolCalls when it is non-nil — confirm against the caller.
MultiContent *entity.Content `json:"multi_content,omitempty"`
}

type GetPromptParams struct {
Expand Down
157 changes: 157 additions & 0 deletions backend/modules/evaluation/domain/component/rpc/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package rpc
import (
"testing"

"github.com/bytedance/gg/gptr"
"github.com/stretchr/testify/assert"

"github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity"
Expand Down Expand Up @@ -119,6 +120,162 @@ func TestExecutePromptParam_Structure_Integrity(t *testing.T) {
}
}

// TestExecutePromptParam_UserQuery checks that ExecutePromptParam keeps the
// UserQuery message it is built with, for a text message, a multipart message
// and a nil message.
func TestExecutePromptParam_UserQuery(t *testing.T) {
	textQuery := &entity.Message{
		Role: entity.RoleUser,
		Content: &entity.Content{
			ContentType: gptr.Of(entity.ContentTypeText),
			Text:        gptr.Of("test user query"),
		},
	}
	multipartQuery := &entity.Message{
		Role: entity.RoleUser,
		Content: &entity.Content{
			ContentType: gptr.Of(entity.ContentTypeMultipart),
			MultiPart: []*entity.Content{
				{
					ContentType: gptr.Of(entity.ContentTypeText),
					Text:        gptr.Of("part 1"),
				},
				{
					ContentType: gptr.Of(entity.ContentTypeImage),
					Image: &entity.Image{
						URL: gptr.Of("http://example.com/image.jpg"),
					},
				},
			},
		},
	}

	cases := []struct {
		name      string
		userQuery *entity.Message
		wantNil   bool
	}{
		{name: "with_user_query_text_message", userQuery: textQuery, wantNil: false},
		{name: "with_user_query_multipart_message", userQuery: multipartQuery, wantNil: false},
		{name: "without_user_query_nil", userQuery: nil, wantNil: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			param := &ExecutePromptParam{
				PromptID:      12345,
				PromptVersion: "v1.0",
				Variables:     []*entity.VariableVal{},
				History:       []*entity.Message{},
				UserQuery:     tc.userQuery,
			}

			// The nil case only needs to confirm the field stayed unset.
			if tc.wantNil {
				assert.Nil(t, param.UserQuery)
				return
			}

			// Otherwise the stored message must be the one supplied.
			assert.NotNil(t, param.UserQuery)
			assert.Equal(t, tc.userQuery, param.UserQuery)
			assert.Equal(t, entity.RoleUser, param.UserQuery.Role)
		})
	}
}

// TestExecutePromptResult_MultiContent checks that ExecutePromptResult keeps
// the MultiContent value it is built with (text, multipart or nil) and that
// the Content, ToolCalls and TokenUsage fields are stored unchanged beside it.
func TestExecutePromptResult_MultiContent(t *testing.T) {
	multipartContent := &entity.Content{
		ContentType: gptr.Of(entity.ContentTypeMultipart),
		MultiPart: []*entity.Content{
			{
				ContentType: gptr.Of(entity.ContentTypeText),
				Text:        gptr.Of("text part"),
			},
			{
				ContentType: gptr.Of(entity.ContentTypeImage),
				Image: &entity.Image{
					URL: gptr.Of("http://example.com/image.jpg"),
				},
			},
		},
	}

	cases := []struct {
		name          string
		content       *string
		toolCalls     []*entity.ToolCall
		tokenUsage    *entity.TokenUsage
		multiContent  *entity.Content
		expectedType  entity.ContentType
		expectedText  string
		expectedMulti bool
	}{
		{
			name: "with_multi_content_text",
			multiContent: &entity.Content{
				ContentType: gptr.Of(entity.ContentTypeText),
				Text:        gptr.Of("multi content text"),
			},
			expectedType:  entity.ContentTypeText,
			expectedText:  "multi content text",
			expectedMulti: true,
		},
		{
			name:          "with_multi_content_multipart",
			multiContent:  multipartContent,
			expectedType:  entity.ContentTypeMultipart,
			expectedMulti: true,
		},
		{
			name:          "without_multi_content_nil",
			multiContent:  nil,
			expectedMulti: false,
		},
		{
			name:    "with_content_and_multi_content",
			content: gptr.Of("regular content"),
			multiContent: &entity.Content{
				ContentType: gptr.Of(entity.ContentTypeText),
				Text:        gptr.Of("multi content"),
			},
			expectedType:  entity.ContentTypeText,
			expectedText:  "multi content",
			expectedMulti: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := &ExecutePromptResult{
				Content:      tc.content,
				ToolCalls:    tc.toolCalls,
				TokenUsage:   tc.tokenUsage,
				MultiContent: tc.multiContent,
			}

			// MultiContent must be exactly what the case supplied.
			if !tc.expectedMulti {
				assert.Nil(t, got.MultiContent)
			} else {
				assert.NotNil(t, got.MultiContent)
				assert.Equal(t, tc.multiContent, got.MultiContent)
				assert.Equal(t, tc.expectedType, gptr.Indirect(got.MultiContent.ContentType))
				// Text is only compared for cases that declare an expected text.
				if tc.expectedText != "" {
					assert.Equal(t, tc.expectedText, gptr.Indirect(got.MultiContent.Text))
				}
			}

			// The remaining fields must be stored unchanged.
			assert.Equal(t, tc.content, got.Content)
			assert.Equal(t, tc.toolCalls, got.ToolCalls)
			assert.Equal(t, tc.tokenUsage, got.TokenUsage)
		})
	}
}

func TestExecutePromptParam_RuntimeParam_JSON_Scenarios(t *testing.T) {
tests := []struct {
name string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ func (t *PromptSourceEvalTargetServiceImpl) Execute(ctx context.Context, spaceID
}
vals := make([]*entity.VariableVal, 0)
for key, content := range param.Input.InputFields {
if key == consts.InputFieldKeyPromptUserQuery {
exePromptParam.UserQuery = &entity.Message{
Role: entity.RoleUser,
Content: content,
}
delete(param.Input.InputFields, key)
continue
}
if content != nil {
variable := &entity.VariableVal{
Key: gptr.Of(key),
Expand Down Expand Up @@ -118,24 +126,29 @@ func (t *PromptSourceEvalTargetServiceImpl) Execute(ctx context.Context, spaceID
return evaluatorOutputData, entity.EvalTargetRunStatusFail, err
}

var outputStr string

if executePromptResult == nil {
outputStr = ""
} else if executePromptResult.Content != nil {
outputStr = *executePromptResult.Content
} else if executePromptResult.ToolCalls != nil {
outputStr, err = json.MarshalString(executePromptResult.ToolCalls)
var outputContent *entity.Content
if executePromptResult != nil && executePromptResult.MultiContent != nil {
outputContent = executePromptResult.MultiContent
} else {
outputStr = ""
}

evaluatorOutputData.OutputFields = map[string]*entity.Content{
consts.OutputSchemaKey: {
var outputStr string
if executePromptResult == nil {
outputStr = ""
} else if executePromptResult.Content != nil {
outputStr = *executePromptResult.Content
} else if executePromptResult.ToolCalls != nil {
outputStr, err = json.MarshalString(executePromptResult.ToolCalls)
} else {
outputStr = ""
}
outputContent = &entity.Content{
ContentType: gptr.Of(entity.ContentTypeText),
Format: gptr.Of(entity.Markdown),
Text: &outputStr,
},
}
}

evaluatorOutputData.OutputFields = map[string]*entity.Content{
consts.OutputSchemaKey: outputContent,
}

if executePromptResult != nil && executePromptResult.TokenUsage != nil {
Expand Down Expand Up @@ -207,6 +220,11 @@ func (t *PromptSourceEvalTargetServiceImpl) BuildBySource(ctx context.Context, s
JsonSchema: gptr.Of(jsonschema),
})
}
inputSchema = append(inputSchema, &entity.ArgsSchema{
Key: gptr.Of(consts.InputFieldKeyPromptUserQuery),
SupportContentTypes: []entity.ContentType{entity.ContentTypeText, entity.ContentTypeImage, entity.ContentTypeMultipart},
JsonSchema: gptr.Of(consts.StringJsonSchema),
})
}
userIDInContext := session.UserIDInCtxOrEmpty(ctx)
do := &entity.EvalTarget{
Expand Down Expand Up @@ -406,6 +424,20 @@ func (t *PromptSourceEvalTargetServiceImpl) PackSourceVersionInfo(ctx context.Co
PromptID: do.EvalTargetVersion.Prompt.PromptID,
Version: &do.EvalTargetVersion.SourceTargetVersion,
})
existUserQueryKey := false
for _, schema := range do.EvalTargetVersion.InputSchema {
if gptr.Indirect(schema.Key) == consts.InputFieldKeyPromptUserQuery {
existUserQueryKey = true
break
}
}
if !existUserQueryKey { // compatibility with historical data
do.EvalTargetVersion.InputSchema = append(do.EvalTargetVersion.InputSchema, &entity.ArgsSchema{
Key: gptr.Of(consts.InputFieldKeyPromptUserQuery),
SupportContentTypes: []entity.ContentType{entity.ContentTypeText, entity.ContentTypeImage, entity.ContentTypeMultipart},
JsonSchema: gptr.Of(consts.StringJsonSchema),
})
}
}
if len(promptQueries) == 0 {
return nil
Expand Down
Loading
Loading