Skip to content

Commit 20f7fc6

Browse files
committed
add sum function, make logging a bit better
1 parent 6b23702 commit 20f7fc6

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

examples/example5.lua

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ local two_numbers = {
3131

3232
local chat = client:new_chat_session({
3333
temperature = 0,
34-
-- model = "gpt-3.5-turbo-0613",
34+
-- model = "gpt-3.5-turbo-0613", -- In my testing, gpt-3.5 is unable to figure it out
3535
model = "gpt-4-0613",
3636
messages = {
3737
{
@@ -40,9 +40,23 @@ local chat = client:new_chat_session({
4040
}
4141
},
4242
functions = {
43-
{ name = "add", description = "Add two numbers together", parameters = two_numbers },
44-
{ name = "divide", description = "Divide two numbers", parameters = two_numbers },
45-
{ name = "multiply", description = "Multiply two numbers together", parameters = two_numbers },
43+
{ name = "add", description = "Add two numbers, a + b", parameters = two_numbers },
44+
{ name = "subtract", description = "Subtract two numbers, a - b", parameters = two_numbers },
45+
{ name = "divide", description = "Divide two numbers, a / b", parameters = two_numbers },
46+
{ name = "multiply", description = "Multiply two numbers together, a * b", parameters = two_numbers },
47+
48+
{
49+
name = "sum", description = "Add a list of numbers together",
50+
parameters = {
51+
type = "object",
52+
properties = {
53+
numbers = {
54+
type = "array",
55+
items = { type = "number" }
56+
}
57+
}
58+
}
59+
},
4660
{
4761
name = "sqrt", description = "Calculate square root of a number",
4862
parameters = {
@@ -62,21 +76,20 @@ function chat:send(v, ...)
6276
return chat_send(self, v, ...)
6377
end
6478

65-
66-
local one_args = types.annotate(types.string / cjson.decode * types.partial({
67-
a = types.number
68-
}))
69-
70-
local two_args = types.annotate(types.string / cjson.decode * types.partial({
79+
local two_args = types.string / cjson.decode * types.partial({
7180
a = types.number,
7281
b = types.number
73-
}))
82+
})
7483

7584
local funcs = {
7685
add = {
7786
arguments = two_args,
7887
call = function(args) return args.a + args.b end
7988
},
89+
subtract = {
90+
arguments = two_args,
91+
call = function(args) return args.a - args.b end
92+
},
8093
divide = {
8194
arguments = two_args,
8295
call = function(args) return args.a / args.b end
@@ -85,8 +98,22 @@ local funcs = {
8598
arguments = two_args,
8699
call = function(args) return args.a * args.b end
87100
},
101+
sum = {
102+
arguments = types.string / cjson.decode * types.partial({
103+
numbers = types.array(types.number)
104+
}),
105+
call = function(args)
106+
local sum = 0
107+
for _, number in ipairs(args.numbers) do
108+
sum = sum + number
109+
end
110+
return sum
111+
end
112+
},
88113
sqrt = {
89-
arguments = one_args,
114+
arguments = types.string / cjson.decode * types.partial({
115+
a = types.number
116+
}),
90117
call = function(args)
91118
return math.sqrt(args.a)
92119
end
@@ -99,7 +126,7 @@ while true do
99126
local last_message = chat:last_message()
100127

101128
for k, v in pairs(last_message) do
102-
p(k, v)
129+
p("<<", k, v)
103130
end
104131

105132
-- stop if no functions are requested
@@ -113,7 +140,7 @@ while true do
113140
if not func_handler then
114141
assert(chat:send("You called a function that is not declared: " .. func.name))
115142
else
116-
local arguments, err = func_handler.arguments:transform(func.arguments)
143+
local arguments, err = types.annotate(func_handler.arguments):transform(func.arguments)
117144
if not arguments then
118145
assert(chat:send("Invalid arguments for function " .. func.name .. ": " .. err))
119146
else

0 commit comments

Comments
 (0)