Skip to content

Commit 159cff6

Browse files
authored
Merge pull request #17 from maximhq/10-09-fix_response_output_struct_added_to_get_prompt_version
feat: enhance MessagePayload to support both request and result types
2 parents f5e6f06 + d5067ee commit 159cff6

File tree

2 files changed

+473
-13
lines changed

2 files changed

+473
-13
lines changed

apis/prompt.go

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,120 @@ type Message struct {
5252
OriginalType string `json:"originalType"`
5353
}
5454

55-
// MessagePayload contains the role and content of a message
55+
// MessagePayload holds either a request or result payload
5656
type MessagePayload struct {
57-
Role string `json:"role"`
58-
Content MessagePayloadContent `json:"content"`
57+
RequestPayload *ChoiceMessage
58+
ResultPayload *CompletionResultPayload
59+
}
60+
61+
// UnmarshalJSON unmarshals the MessagePayload from JSON
62+
func (m *MessagePayload) UnmarshalJSON(data []byte) error {
63+
// Try to unmarshal as CompletionRequestPayload first
64+
var reqPayload ChoiceMessage
65+
if err := json.Unmarshal(data, &reqPayload); err == nil && reqPayload.Role != "" {
66+
m.RequestPayload = &reqPayload
67+
return nil
68+
}
69+
70+
// Try to unmarshal as CompletionResultPayload
71+
var resPayload CompletionResultPayload
72+
if err := json.Unmarshal(data, &resPayload); err == nil {
73+
m.ResultPayload = &resPayload
74+
return nil
75+
}
76+
77+
return fmt.Errorf("failed to unmarshal MessagePayload")
78+
}
79+
80+
// MarshalJSON marshals the MessagePayload to JSON
81+
func (m *MessagePayload) MarshalJSON() ([]byte, error) {
82+
if m.RequestPayload != nil {
83+
return json.Marshal(m.RequestPayload)
84+
}
85+
return json.Marshal(m.ResultPayload)
86+
}
87+
88+
// CompletionResultPayload contains the completion result information
89+
type CompletionResultPayload struct {
90+
ID string `json:"id"`
91+
Cost Cost `json:"cost"`
92+
Model string `json:"model"`
93+
Trace Trace `json:"trace"`
94+
Usage Usage `json:"usage"`
95+
Choices []Choice `json:"choices"`
96+
Provider string `json:"provider"`
97+
ModelParams map[string]interface{} `json:"modelParams"`
98+
VariableBoundRetrievals map[string]interface{} `json:"variableBoundRetrievals"`
99+
}
100+
101+
// Cost represents token cost information
102+
type Cost struct {
103+
Input float64 `json:"input"`
104+
Total float64 `json:"total"`
105+
Output float64 `json:"output"`
106+
}
107+
108+
// Trace contains input/output trace information
109+
type Trace struct {
110+
Input TraceInput `json:"input"`
111+
Output TraceOutput `json:"output"`
112+
}
113+
114+
// TraceInput contains the input messages for the trace
115+
type TraceInput struct {
116+
Messages []ChoiceMessage `json:"messages"`
117+
}
118+
119+
// TraceOutput contains the output from the completion
120+
type TraceOutput struct {
121+
ID string `json:"id"`
122+
Model string `json:"model"`
123+
Usage Usage `json:"usage"`
124+
Object string `json:"object"`
125+
Choices []Choice `json:"choices"`
126+
Created int64 `json:"created"`
127+
ServiceTier string `json:"service_tier"`
128+
SystemFingerprint string `json:"system_fingerprint"`
129+
}
130+
131+
// Usage contains token usage information
132+
type Usage struct {
133+
Latency float64 `json:"latency,omitempty"`
134+
TotalTokens int `json:"total_tokens"`
135+
PromptTokens int `json:"prompt_tokens"`
136+
CompletionTokens int `json:"completion_tokens"`
137+
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
138+
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
139+
}
140+
141+
// PromptTokensDetails contains details about prompt tokens
142+
type PromptTokensDetails struct {
143+
AudioTokens int `json:"audio_tokens"`
144+
CachedTokens int `json:"cached_tokens"`
145+
}
146+
147+
// CompletionTokensDetails contains details about completion tokens
148+
type CompletionTokensDetails struct {
149+
AudioTokens int `json:"audio_tokens"`
150+
ReasoningTokens int `json:"reasoning_tokens"`
151+
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
152+
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
153+
}
154+
155+
// Choice represents a completion choice
156+
type Choice struct {
157+
Index int `json:"index"`
158+
Message ChoiceMessage `json:"message"`
159+
FinishReason string `json:"finish_reason"`
160+
Logprobs interface{} `json:"logprobs"`
161+
}
162+
163+
// ChoiceMessage represents a message in a choice
164+
type ChoiceMessage struct {
165+
Role string `json:"role"`
166+
Content MessagePayloadContent `json:"content"`
167+
Refusal *string `json:"refusal"`
168+
Annotations []interface{} `json:"annotations,omitempty"`
59169
}
60170

61171
type MessagePayloadContent struct {
@@ -68,6 +178,7 @@ type MessagePayloadContentBlock struct {
68178
Text string `json:"text"`
69179
}
70180

181+
// UnmarshalJSON unmarshals the MessagePayloadContent from JSON
71182
func (m *MessagePayloadContent) UnmarshalJSON(data []byte) error {
72183
var messageStr string
73184
if err := json.Unmarshal(data, &messageStr); err == nil {
@@ -82,6 +193,14 @@ func (m *MessagePayloadContent) UnmarshalJSON(data []byte) error {
82193
return fmt.Errorf("failed to unmarshal MessagePayloadContent")
83194
}
84195

196+
// MarshalJSON marshals the MessagePayloadContent to JSON
197+
func (m *MessagePayloadContent) MarshalJSON() ([]byte, error) {
198+
if m.MessagePayloadContentStr != nil {
199+
return json.Marshal(m.MessagePayloadContentStr)
200+
}
201+
return json.Marshal(m.MessagePayloadContentArray)
202+
}
203+
85204
// ModelParameters contains the model configuration parameters
86205
type ModelParameters struct {
87206
N int `json:"n"`

0 commit comments

Comments
 (0)