
Commit 6d289c6

Support loading prompt from yml file (#44)
2 parents: 290f4d6 + f74c60a

3 files changed: +144 -0 lines

cmd/run/run.go

Lines changed: 60 additions & 0 deletions
@@ -21,6 +21,7 @@ import (
     "github.com/github/gh-models/pkg/util"
     "github.com/spf13/cobra"
     "github.com/spf13/pflag"
+    "gopkg.in/yaml.v3"
 )

 // ModelParameters represents the parameters that can be set for a model run.
@@ -188,6 +189,22 @@ func isPipe(r io.Reader) bool {
     return false
 }

+// promptFile mirrors the format of .prompt.yml
+type promptFile struct {
+    Name            string `yaml:"name"`
+    Description     string `yaml:"description"`
+    Model           string `yaml:"model"`
+    ModelParameters struct {
+        MaxTokens   *int     `yaml:"maxTokens"`
+        Temperature *float64 `yaml:"temperature"`
+        TopP        *float64 `yaml:"topP"`
+    } `yaml:"modelParameters"`
+    Messages []struct {
+        Role    string `yaml:"role"`
+        Content string `yaml:"content"`
+    } `yaml:"messages"`
+}
+
 // NewRunCommand returns a new gh command for running a model.
 func NewRunCommand(cfg *command.Config) *cobra.Command {
     cmd := &cobra.Command{
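
As an aside, a minimal sketch of how gopkg.in/yaml.v3 fills this struct; the YAML literal below is invented for illustration, and only the struct's shape comes from the diff above. The pointer-typed parameter fields stay nil when their key is absent, which is how an unset maxTokens is later distinguishable from an explicit 0:

    package main

    import (
        "fmt"

        "gopkg.in/yaml.v3"
    )

    // Trimmed copy of the promptFile struct added in this commit.
    type promptFile struct {
        Name            string `yaml:"name"`
        Model           string `yaml:"model"`
        ModelParameters struct {
            MaxTokens   *int     `yaml:"maxTokens"`
            Temperature *float64 `yaml:"temperature"`
        } `yaml:"modelParameters"`
    }

    func main() {
        src := []byte("name: Demo\nmodel: openai/gpt-4o-mini\nmodelParameters:\n  temperature: 0.5\n")
        var p promptFile
        if err := yaml.Unmarshal(src, &p); err != nil {
            panic(err)
        }
        // maxTokens was omitted, so its pointer is nil rather than zero.
        fmt.Println(p.Model, *p.ModelParameters.Temperature, p.ModelParameters.MaxTokens == nil)
    }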
@@ -208,6 +225,24 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
         Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"",
         Args:    cobra.ArbitraryArgs,
         RunE: func(cmd *cobra.Command, args []string) error {
+            filePath, _ := cmd.Flags().GetString("file")
+            var pf *promptFile
+            if filePath != "" {
+                b, err := os.ReadFile(filePath)
+                if err != nil {
+                    return err
+                }
+                p := promptFile{}
+                if err := yaml.Unmarshal(b, &p); err != nil {
+                    return err
+                }
+                pf = &p
+                // Inject model name as the first positional arg if user didn't supply one
+                if pf.Model != "" && len(args) == 0 {
+                    args = append([]string{pf.Model}, args...)
+                }
+            }
+
             cmdHandler := newRunCommandHandler(cmd, cfg, args)
             if cmdHandler == nil {
                 return nil
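
Given the injection above, --file can stand alone: with no positional arguments, the YAML's model field becomes args[0]. Note the guard is len(args) == 0, so pairing --file with an inline prompt still requires naming the model explicitly, as the test below does. A usage sketch against the sample file from this commit:

    # model resolved from the YAML's `model:` field
    gh models run --file s.prompt.yml

    # with extra positional args, the model must be spelled out
    gh models run --file s.prompt.yml openai/gpt-4o-mini "Hello there!"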
@@ -248,12 +283,36 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
                 systemPrompt: systemPrompt,
             }

+            // preload conversation & parameters from YAML
+            if pf != nil {
+                for _, m := range pf.Messages {
+                    switch strings.ToLower(m.Role) {
+                    case "system":
+                        if conversation.systemPrompt == "" {
+                            conversation.systemPrompt = m.Content
+                        } else {
+                            conversation.AddMessage(azuremodels.ChatMessageRoleSystem, m.Content)
+                        }
+                    case "user":
+                        conversation.AddMessage(azuremodels.ChatMessageRoleUser, m.Content)
+                    case "assistant":
+                        conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, m.Content)
+                    }
+                }
+            }
+
             mp := ModelParameters{}
             err = mp.PopulateFromFlags(cmd.Flags())
             if err != nil {
                 return err
             }

+            if pf != nil {
+                mp.maxTokens = pf.ModelParameters.MaxTokens
+                mp.temperature = pf.ModelParameters.Temperature
+                mp.topP = pf.ModelParameters.TopP
+            }
+
             for {
                 prompt := ""
                 if initialPrompt != "" {
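
One consequence of the unconditional copy above, worth flagging: when --file is set, the YAML values replace whatever PopulateFromFlags read from --max-tokens, --temperature, and --top-p, even for keys the YAML never sets, since their pointers are simply nil. A small sketch of that pointer behavior, with made-up values:

    package main

    import "fmt"

    func main() {
        // As if the user passed --temperature 0.9 on the command line.
        fromFlag := 0.9
        temperature := &fromFlag

        // The YAML omitted temperature, so its field unmarshals to nil;
        // the unconditional assignment then discards the flag value.
        var fromYAML *float64
        temperature = fromYAML

        fmt.Println(temperature == nil) // true: the flag value is gone
    }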
@@ -369,6 +428,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
         },
     }

+    cmd.Flags().String("file", "", "Path to a .prompt.yml file.")
     cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.")
     cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
     cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")

cmd/run/run_test.go

Lines changed: 70 additions & 0 deletions
@@ -3,6 +3,7 @@ package run
 import (
     "bytes"
     "context"
+    "os"
     "regexp"
     "testing"

@@ -80,4 +81,73 @@ func TestRun(t *testing.T) {
         require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output)
         require.Empty(t, errBuf.String())
     })
+
+    t.Run("--file pre-loads YAML from file", func(t *testing.T) {
+        const yamlBody = `
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/test-model
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer.
+  - role: user
+    content: Hello there!
+`
+        tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
+        require.NoError(t, err)
+        _, err = tmp.WriteString(yamlBody)
+        require.NoError(t, err)
+        require.NoError(t, tmp.Close())
+
+        client := azuremodels.NewMockClient()
+        modelSummary := &azuremodels.ModelSummary{
+            Name:      "test-model",
+            Publisher: "openai",
+            Task:      "chat-completion",
+        }
+        client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+            return []*azuremodels.ModelSummary{modelSummary}, nil
+        }
+
+        var capturedReq azuremodels.ChatCompletionOptions
+        reply := "Summary - foo"
+        chatCompletion := azuremodels.ChatCompletion{
+            Choices: []azuremodels.ChatChoice{{
+                Message: &azuremodels.ChatChoiceMessage{
+                    Content: util.Ptr(reply),
+                    Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+                },
+            }},
+        }
+        client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+            capturedReq = opt
+            return &azuremodels.ChatCompletionResponse{
+                Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+            }, nil
+        }
+
+        out := new(bytes.Buffer)
+        cfg := command.NewConfig(out, out, client, true, 100)
+        runCmd := NewRunCommand(cfg)
+        runCmd.SetArgs([]string{
+            "--file", tmp.Name(),
+            azuremodels.FormatIdentifier("openai", "test-model"),
+            "foo?",
+        })
+
+        _, err = runCmd.ExecuteC()
+        require.NoError(t, err)
+
+        require.Equal(t, 3, len(capturedReq.Messages))
+        require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
+        require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
+        require.Equal(t, "foo?", *capturedReq.Messages[2].Content)
+
+        require.NotNil(t, capturedReq.Temperature)
+        require.Equal(t, 0.5, *capturedReq.Temperature)
+
+        require.Contains(t, out.String(), reply) // response streamed to output
+    })
 }
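
To run just this package's tests with the standard toolchain:

    go test ./cmd/run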

s.prompt.yml

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/gpt-4o-mini
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer. Your only job is to summarize text given to you.
+  - role: user
+    content: |
+      Summarize the given text, beginning with "Summary -":
+      <text>
+      {{input}}
+      </text>
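
One thing to note about this sample: nothing in the diff substitutes the {{input}} placeholder. Message content is handed to AddMessage verbatim, so {{input}} reaches the model as literal text. If substitution were wanted, a hypothetical helper (not part of this commit) could be as small as:

    package main

    import (
        "fmt"
        "strings"
    )

    // renderContent is a hypothetical helper, not in this commit:
    // it fills the {{input}} placeholder used by s.prompt.yml.
    func renderContent(content, input string) string {
        return strings.ReplaceAll(content, "{{input}}", input)
    }

    func main() {
        tmpl := "Summarize the given text, beginning with \"Summary -\":\n<text>\n{{input}}\n</text>"
        fmt.Println(renderContent(tmpl, "The quick brown fox jumps over the lazy dog."))
    }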
