diff --git a/openai/init.moon b/openai/init.moon index 072242d..60fc3e4 100644 --- a/openai/init.moon +++ b/openai/init.moon @@ -82,19 +82,6 @@ parse_chat_response = types.partial { -- } -parse_completion_chunk = types.partial { - object: "chat.completion.chunk" - -- not sure of the whole range of chunks, so for now we strictly parse an append - choices: types.shape { - types.partial { - delta: types.shape { - "content": types.string\tag "content" - } - index: types.number\tag "index" - } - } -} - -- lpeg pattern to read a json data block from the front of a string, returns -- the json blob and the rest of the string if it could parse one consume_json_head = do @@ -183,20 +170,8 @@ class ChatSession return nil, err_msg, response - -- if we are streaming we need to pase the entire fragmented response if stream_callback - assert type(response) == "string", - "Expected string response from streaming output" - - parts = {} - f = @client\create_stream_filter (c) -> - table.insert parts, c.content - - f response - message = { - role: "assistant" - content: table.concat parts - } + message = response.choices[1].message if append_response @append_message message @@ -250,7 +225,7 @@ class OpenAI break accumulation_buffer = rest - if chunk = parse_completion_chunk cjson.decode json_blob + if chunk = cjson.decode json_blob chunk_callback chunk ... @@ -273,10 +248,7 @@ class OpenAI for k,v in pairs opts payload[k] = v - stream_filter = if payload.stream - @create_stream_filter chunk_callback - - @_request "POST", "/chat/completions", payload, nil, stream_filter + @_request "POST", "/chat/completions", payload, nil, if payload.stream then chunk_callback else nil -- call /completions -- opts: additional parameters as described in https://platform.openai.com/docs/api-reference/completions @@ -362,7 +334,7 @@ class OpenAI image_generation: (params) => @_request "POST", "/images/generations", params - _request: (method, path, payload, more_headers, stream_fn) => + _request: (method, path, payload, more_headers, chunk_callback) => assert path, "missing path" assert method, "missing method" @@ -392,7 +364,17 @@ class OpenAI sink = ltn12.sink.table out - if stream_fn + parts = {} + if chunk_callback + stream_fn = @create_stream_filter (c) -> + c0 = c.choices[1] + part = parts[c0.index] or {} + part.data = c + part.finish_reason = c0.finish_reason + parts[c0.index] = part + if c0.delta.content and c0.delta.content ~= cjson.null + table.insert part, c0.delta.content + chunk_callback(c) sink = ltn12.sink.chain stream_fn, sink _, status, out_headers = @get_http!.request { @@ -403,6 +385,23 @@ class OpenAI :headers } + if status == 200 and chunk_callback + choices = {} + index = 0 + local data + while parts[index] + part = parts[index] + data = part.data + message = { + role: "assistant" + content: table.concat part + } + choices[index+1] = { :index, :message, finish_reason: part.finish_reason } + index += 1 + data.object = "chat.completion" + data.choices = choices + return status, data, out_headers + response = table.concat out pcall -> response = cjson.decode response status, response, out_headers