Skip to content

Commit

Permalink
[anthropic] Support: provider-defined tools, token usage, cache-control
Browse files Browse the repository at this point in the history
  • Loading branch information
savil committed Feb 13, 2025
1 parent 0672790 commit 323d04c
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 50 deletions.
131 changes: 102 additions & 29 deletions llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,13 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
choices[i] = &llms.ContentChoice{
Content: textContent.Text,
StopReason: result.StopReason,
// TODO: this is not fully accurate. The Usage data covers all the Choices
// in the response, not just this Choice.
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
"CacheCreationInputTokens": result.Usage.CacheCreationInputTokens,
"CacheReadInputTokens": result.Usage.CacheReadInputTokens,
},
}
} else {
Expand All @@ -189,9 +193,13 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
},
},
StopReason: result.StopReason,
// TODO: this is not fully accurate. The Usage data covers all the Choices
// in the response, not just this Choice.
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
"CacheCreationInputTokens": result.Usage.CacheCreationInputTokens,
"CacheReadInputTokens": result.Usage.CacheReadInputTokens,
},
}
} else {
Expand All @@ -211,82 +219,125 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
func toolsToTools(tools []llms.Tool) []anthropicclient.Tool {
toolReq := make([]anthropicclient.Tool, len(tools))
for i, tool := range tools {
toolReq[i] = anthropicclient.Tool{
anthropicTool := anthropicclient.Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: tool.Function.Parameters,

Type: tool.Type,
}
if tool.Function.ProviderMetadata != nil {
if displayHeightPx, ok := tool.Function.ProviderMetadata["display_height_px"]; ok {
anthropicTool.DisplayHeightPx = displayHeightPx.(int)
}
if displayWidthPx, ok := tool.Function.ProviderMetadata["display_width_px"]; ok {
anthropicTool.DisplayWidthPx = displayWidthPx.(int)
}
if displayNumber, ok := tool.Function.ProviderMetadata["display_number"]; ok {
anthropicTool.DisplayNumber = displayNumber.(int)
}
if cacheControl, ok := tool.Function.ProviderMetadata["cache_control"]; ok {
anthropicTool.CacheControl = struct {
Type string "json:\"type,omitempty\""
}{
Type: cacheControl.(string),
}
}
}

toolReq[i] = anthropicTool
}
return toolReq
}

func processMessages(messages []llms.MessageContent) ([]anthropicclient.ChatMessage, string, error) {
func processMessages(messages []llms.MessageContent) ([]anthropicclient.ChatMessage, []anthropicclient.TextContent, error) {
chatMessages := make([]anthropicclient.ChatMessage, 0, len(messages))
systemPrompt := ""
systemPrompt := make([]anthropicclient.TextContent, 0)
for _, msg := range messages {
switch msg.Role {
case llms.ChatMessageTypeSystem:
content, err := handleSystemMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle system message: %w", err)
return nil, nil, fmt.Errorf("anthropic: failed to handle system message: %w", err)
}
systemPrompt += content
systemPrompt = append(systemPrompt, content)
case llms.ChatMessageTypeHuman:
chatMessage, err := handleHumanMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle human message: %w", err)
return nil, nil, fmt.Errorf("anthropic: failed to handle human message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeAI:
chatMessage, err := handleAIMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle AI message: %w", err)
return nil, nil, fmt.Errorf("anthropic: failed to handle AI message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeTool:
chatMessage, err := handleToolMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle tool message: %w", err)
return nil, nil, fmt.Errorf("anthropic: failed to handle tool message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeGeneric, llms.ChatMessageTypeFunction:
return nil, "", fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
return nil, nil, fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
default:
return nil, "", fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
return nil, nil, fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
}
}
return chatMessages, systemPrompt, nil
}

func handleSystemMessage(msg llms.MessageContent) (string, error) {
func handleSystemMessage(msg llms.MessageContent) (anthropicclient.TextContent, error) {
if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
return textContent.Text, nil
cacheControl, err := getCacheControl(textContent)
if err != nil {
return anthropicclient.TextContent{}, fmt.Errorf("anthropic: failed to get cache control: %w", err)
}
return anthropicclient.TextContent{
Type: "text",
Text: textContent.Text,
CacheControl: cacheControl,
}, nil
}
return "", fmt.Errorf("anthropic: %w for system message", ErrInvalidContentType)
return anthropicclient.TextContent{}, fmt.Errorf("anthropic: %w for system message", ErrInvalidContentType)
}

// handleHumanMessage converts a user message into an Anthropic chat
// message holding a single text content block, honoring any
// cache_control metadata on the part. Only the first part is inspected
// and it must be llms.TextContent.
func handleHumanMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
	if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
		cacheControl, err := getCacheControl(textContent)
		if err != nil {
			return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to get cache control: %w", err)
		}
		return anthropicclient.ChatMessage{
			Role: RoleUser,
			Content: []anthropicclient.Content{&anthropicclient.TextContent{
				Type:         "text",
				Text:         textContent.Text,
				CacheControl: cacheControl,
			}},
		}, nil
	}
	return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for human message", ErrInvalidContentType)
}

func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
if toolCall, ok := msg.Parts[0].(llms.ToolCall); ok {
cacheControl, err := getCacheControl(toolCall)
if err != nil {
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to get cache control: %w", err)
}
var inputStruct map[string]interface{}
err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct)
err = json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct)
if err != nil {
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
}
toolUse := anthropicclient.ToolUseContent{
Type: "tool_use",
ID: toolCall.ID,
Name: toolCall.FunctionCall.Name,
Input: inputStruct,
Type: "tool_use",
ID: toolCall.ID,
Name: toolCall.FunctionCall.Name,
Input: inputStruct,
CacheControl: cacheControl,
}

return anthropicclient.ChatMessage{
Expand All @@ -295,11 +346,16 @@ func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, erro
}, nil
}
if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
cacheControl, err := getCacheControl(textContent)
if err != nil {
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to get cache control: %w", err)
}
return anthropicclient.ChatMessage{
Role: RoleAssistant,
Content: []anthropicclient.Content{&anthropicclient.TextContent{
Type: "text",
Text: textContent.Text,
Type: "text",
Text: textContent.Text,
CacheControl: cacheControl,
}},
}, nil
}
Expand All @@ -314,10 +370,19 @@ type ToolResult struct {

func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
if toolCallResponse, ok := msg.Parts[0].(llms.ToolCallResponse); ok {
cacheControl, err := getCacheControl(toolCallResponse)
if err != nil {
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to get cache control: %w", err)
}
toolContent := anthropicclient.ToolResultContent{
Type: "tool_result",
ToolUseID: toolCallResponse.ToolCallID,
Content: toolCallResponse.Content,
Type: "tool_result",
ToolUseID: toolCallResponse.ToolCallID,
Content: toolCallResponse.Content,
CacheControl: cacheControl,
}

if toolCallResponse.MultiContent != nil {
toolContent.MultiContent = toolCallResponse.MultiContent
}

return anthropicclient.ChatMessage{
Expand All @@ -327,3 +392,11 @@ func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, er
}
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message", ErrInvalidContentType)
}

// getCacheControl extracts the Anthropic prompt-caching directive (e.g.
// {"type": "ephemeral"}) from a content part's provider metadata. It
// returns (nil, nil) when there is no metadata or no "cache_control"
// entry, and an error when the entry has an unexpected type.
func getCacheControl(part llms.ContentPart) (map[string]string, error) {
	md := part.GetProviderMetadata()
	if md == nil {
		return nil, nil
	}
	// Comma-ok assertion: the original bare assertion panicked when the
	// key was absent (nil interface) or held a different type.
	raw, ok := md["cache_control"]
	if !ok {
		return nil, nil
	}
	cc, ok := raw.(map[string]string)
	if !ok {
		return nil, fmt.Errorf("anthropic: cache_control metadata has unexpected type %T", raw)
	}
	return cc, nil
}
2 changes: 1 addition & 1 deletion llms/anthropic/internal/anthropicclient/anthropicclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*C
type MessageRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
System string `json:"system,omitempty"`
System any `json:"system,omitempty"` // string or []TextContent
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Expand Down
83 changes: 70 additions & 13 deletions llms/anthropic/internal/anthropicclient/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"log"
"net/http"
"strings"

"github.com/tmc/langchaingo/llms"
)

var (
Expand All @@ -27,14 +29,14 @@ var (
)

type ChatMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
Role string `json:"role"`
Content any `json:"content"` // []Content or string
}

type messagePayload struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
System string `json:"system,omitempty"`
System any `json:"system,omitempty"` // string or []TextContent
MaxTokens int `json:"max_tokens,omitempty"`
StopWords []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Expand All @@ -50,27 +52,39 @@ type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema any `json:"input_schema,omitempty"`

// The fields below are used for the built-in tools:
// https://docs.anthropic.com/en/docs/build-with-claude/computer-use#understand-anthropic-defined-tools
Type string `json:"type"`
DisplayHeightPx int `json:"display_height_px,omitempty"`
DisplayWidthPx int `json:"display_width_px,omitempty"`
DisplayNumber int `json:"display_number,omitempty"`
CacheControl struct {
Type string `json:"type,omitempty"` // valid value: "ephemeral"
} `json:"cache_control,omitempty"`
}

// Content can be TextContent, ToolUseContent, or ToolResultContent
// depending on the type.
type Content interface {
	GetType() string
}

type TextContent struct {
Type string `json:"type"`
Text string `json:"text"`
Type string `json:"type"`
Text string `json:"text"`
CacheControl map[string]string `json:"cache_control,omitempty"`
}

func (tc TextContent) GetType() string {
return tc.Type
}

type ToolUseContent struct {
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
CacheControl map[string]string `json:"cache_control,omitempty"`
}

func (tuc ToolUseContent) GetType() string {
Expand All @@ -80,7 +94,47 @@ func (tuc ToolUseContent) GetType() string {
type ToolResultContent struct {
Type string `json:"type"`
ToolUseID string `json:"tool_use_id"`
Content string `json:"content"`

// The content of the message.
// This field is mutually exclusive with MultiContent.
Content string `json:"-"`

MultiContent []llms.ToolResultContentPart `json:"-"`
CacheControl map[string]string `json:"cache_control,omitempty"`
}

// json marshal ToolResultContent such that either Content or MultiContent is set, but not both.
// the json key for both is "content"
func (trc ToolResultContent) MarshalJSON() ([]byte, error) {
if trc.Content != "" && len(trc.MultiContent) > 0 {
return nil, fmt.Errorf("both Content and MultiContent cannot be set in ToolResultContents")
}

type alias ToolResultContent

if len(trc.MultiContent) > 0 {
result, err := json.Marshal(struct {
alias
Content []llms.ToolResultContentPart `json:"content"`
}{
alias: (alias)(trc),
Content: trc.MultiContent,
})
if err != nil {
return nil, fmt.Errorf("marshal multi content: %w", err)
}
// fmt.Printf("ToolResultContent json: %s\n", string(result))

return result, nil
}

return json.Marshal(struct {
alias
Content string `json:"content"`
}{
alias: (alias)(trc),
Content: trc.Content,
})
}

func (trc ToolResultContent) GetType() string {
Expand All @@ -96,8 +150,10 @@ type MessageResponsePayload struct {
StopSequence string `json:"stop_sequence"`
Type string `json:"type"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"`
}

Expand Down Expand Up @@ -297,6 +353,7 @@ func handleMessageStartEvent(event map[string]interface{}, response MessageRespo
response.Role = getString(message, "role")
response.Type = getString(message, "type")
response.Usage.InputTokens = int(inputTokens)
// TODO: handle cached token usage

return response, nil
}
Expand Down
Loading

0 comments on commit 323d04c

Please sign in to comment.