Skip to content

Commit 2bdaa9e

Browse files
committed
added support for chat.completion to return sse response
1 parent c9dda3d commit 2bdaa9e

File tree

1 file changed

+79
-176
lines changed

1 file changed

+79
-176
lines changed

gpt4all-chat/server.cpp

Lines changed: 79 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
#include <QtLogging>
2020

2121
#include <iostream>
22-
#include <string>
23-
#include <type_traits>
2422
#include <utility>
2523

2624
using namespace Qt::Literals::StringLiterals;
@@ -207,26 +205,29 @@ void Server::start()
207205

208206
QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat)
209207
{
210-
// We've been asked to do a completion...
208+
// Parse JSON request
211209
QJsonParseError err;
212210
const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err);
213211
if (err.error || !document.isObject()) {
214-
std::cerr << "ERROR: invalid json in completions body" << std::endl;
212+
std::cerr << "ERROR: invalid JSON in completions body" << std::endl;
215213
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
216214
}
215+
217216
#if defined(DEBUG)
218217
printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented)));
219218
fflush(stdout);
220219
#endif
220+
221221
const QJsonObject body = document.object();
222-
if (!body.contains("model")) { // required
223-
std::cerr << "ERROR: completions contains no model" << std::endl;
222+
if (!body.contains("model")) {
223+
std::cerr << "ERROR: completions contain no model" << std::endl;
224224
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
225225
}
226+
226227
QJsonArray messages;
227228
if (isChat) {
228229
if (!body.contains("messages")) {
229-
std::cerr << "ERROR: chat completions contains no messages" << std::endl;
230+
std::cerr << "ERROR: chat completions contain no messages" << std::endl;
230231
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
231232
}
232233
messages = body["messages"].toArray();
@@ -236,16 +237,12 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
236237
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
237238
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
238239
for (const ModelInfo &info : modelList) {
239-
Q_ASSERT(info.installed);
240-
if (!info.installed)
241-
continue;
242-
if (modelRequested == info.name() || modelRequested == info.filename()) {
240+
if (info.installed && (modelRequested == info.name() || modelRequested == info.filename())) {
243241
modelInfo = info;
244242
break;
245243
}
246244
}
247245

248-
// We only support one prompt for now
249246
QList<QString> prompts;
250247
if (body.contains("prompt")) {
251248
QJsonValue promptValue = body["prompt"];
@@ -256,102 +253,32 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
256253
for (const QJsonValue &v : array)
257254
prompts.append(v.toString());
258255
}
259-
} else
256+
} else {
260257
prompts.append(" ");
261-
262-
int max_tokens = 16;
263-
if (body.contains("max_tokens"))
264-
max_tokens = body["max_tokens"].toInt();
265-
266-
float temperature = 1.f;
267-
if (body.contains("temperature"))
268-
temperature = body["temperature"].toDouble();
269-
270-
float top_p = 1.f;
271-
if (body.contains("top_p"))
272-
top_p = body["top_p"].toDouble();
273-
274-
float min_p = 0.f;
275-
if (body.contains("min_p"))
276-
min_p = body["min_p"].toDouble();
277-
278-
int n = 1;
279-
if (body.contains("n"))
280-
n = body["n"].toInt();
281-
282-
int logprobs = -1; // supposed to be null by default??
283-
if (body.contains("logprobs"))
284-
logprobs = body["logprobs"].toInt();
285-
286-
bool echo = false;
287-
if (body.contains("echo"))
288-
echo = body["echo"].toBool();
289-
290-
// We currently don't support any of the following...
291-
#if 0
292-
// FIXME: Need configurable reverse prompts
293-
QList<QString> stop;
294-
if (body.contains("stop")) {
295-
QJsonValue stopValue = body["stop"];
296-
if (stopValue.isString())
297-
stop.append(stopValue.toString());
298-
else {
299-
QJsonArray array = stopValue.toArray();
300-
for (QJsonValue v : array)
301-
stop.append(v.toString());
302-
}
303258
}
304259

305-
// FIXME: QHttpServer doesn't support server-sent events
306-
bool stream = false;
307-
if (body.contains("stream"))
308-
stream = body["stream"].toBool();
309-
310-
// FIXME: What does this do?
311-
QString suffix;
312-
if (body.contains("suffix"))
313-
suffix = body["suffix"].toString();
314-
315-
// FIXME: We don't support
316-
float presence_penalty = 0.f;
317-
if (body.contains("presence_penalty"))
318-
top_p = body["presence_penalty"].toDouble();
319-
320-
// FIXME: We don't support
321-
float frequency_penalty = 0.f;
322-
if (body.contains("frequency_penalty"))
323-
top_p = body["frequency_penalty"].toDouble();
324-
325-
// FIXME: We don't support
326-
int best_of = 1;
327-
if (body.contains("best_of"))
328-
logprobs = body["best_of"].toInt();
329-
330-
// FIXME: We don't need
331-
QString user;
332-
if (body.contains("user"))
333-
suffix = body["user"].toString();
334-
#endif
260+
int max_tokens = body.value("max_tokens").toInt(16);
261+
float temperature = body.value("temperature").toDouble(1.0);
262+
float top_p = body.value("top_p").toDouble(1.0);
263+
float min_p = body.value("min_p").toDouble(0.0);
264+
int n = body.value("n").toInt(1);
265+
bool echo = body.value("echo").toBool(false);
335266

336267
QString actualPrompt = prompts.first();
337268

338-
// if we're a chat completion we have messages which means we need to prepend these to the prompt
339269
if (!messages.isEmpty()) {
340270
QList<QString> chats;
341-
for (int i = 0; i < messages.count(); ++i) {
342-
QJsonValue v = messages.at(i);
343-
QString content = v.toObject()["content"].toString();
271+
for (int i = 0; i < messages.count(); ++i) {
272+
QString content = messages.at(i).toObject()["content"].toString();
344273
if (!content.endsWith("\n") && i < messages.count() - 1)
345274
content += "\n";
346275
chats.append(content);
347276
}
348277
actualPrompt.prepend(chats.join("\n"));
349278
}
350279

351-
// adds prompt/response items to GUI
352280
emit requestServerNewPromptResponsePair(actualPrompt); // blocks
353281

354-
// load the new model if necessary
355282
setShouldBeLoaded(true);
356283

357284
if (modelInfo.filename().isEmpty()) {
@@ -362,107 +289,83 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
362289
return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
363290
}
364291

365-
// don't remember any context
366292
resetContext();
367293

368-
const QString promptTemplate = modelInfo.promptTemplate();
369-
const float top_k = modelInfo.topK();
370-
const int n_batch = modelInfo.promptBatchSize();
371-
const float repeat_penalty = modelInfo.repeatPenalty();
372-
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
294+
QByteArray responseData;
295+
QTextStream stream(&responseData, QIODevice::WriteOnly);
296+
297+
QString randomId = "chatcmpl-" + QUuid::createUuid().toString(QUuid::WithoutBraces).replace("-", "");
373298

374-
int promptTokens = 0;
375-
int responseTokens = 0;
376-
QList<QPair<QString, QList<ResultInfo>>> responses;
377299
for (int i = 0; i < n; ++i) {
378-
if (!promptInternal(
379-
m_collections,
380-
actualPrompt,
381-
promptTemplate,
382-
max_tokens /*n_predict*/,
383-
top_k,
384-
top_p,
385-
min_p,
386-
temperature,
387-
n_batch,
388-
repeat_penalty,
389-
repeat_last_n)) {
300+
if (!promptInternal(m_collections,
301+
actualPrompt,
302+
modelInfo.promptTemplate(),
303+
max_tokens /*n_predict*/,
304+
modelInfo.topK(),
305+
top_p,
306+
min_p,
307+
temperature,
308+
modelInfo.promptBatchSize(),
309+
modelInfo.repeatPenalty(),
310+
modelInfo.repeatPenaltyTokens())) {
390311

391312
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
392313
return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
393314
}
394-
QString echoedPrompt = actualPrompt;
395-
if (!echoedPrompt.endsWith("\n"))
396-
echoedPrompt += "\n";
397-
responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults));
398-
if (!promptTokens)
399-
promptTokens += m_promptTokens;
400-
responseTokens += m_promptResponseTokens - m_promptTokens;
401-
if (i != n - 1)
402-
resetResponse();
403-
}
404315

405-
QJsonObject responseObject;
406-
responseObject.insert("id", "foobarbaz");
407-
responseObject.insert("object", "text_completion");
408-
responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
409-
responseObject.insert("model", modelInfo.name());
316+
QString result = (echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response();
410317

411-
QJsonArray choices;
318+
for (const QString &token : result.split(' ')) {
319+
QJsonObject delta;
320+
delta.insert("content", token + " ");
412321

413-
if (isChat) {
414-
int index = 0;
415-
for (const auto &r : responses) {
416-
QString result = r.first;
417-
QList<ResultInfo> infos = r.second;
418322
QJsonObject choice;
419-
choice.insert("index", index++);
420-
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
421-
QJsonObject message;
422-
message.insert("role", "assistant");
423-
message.insert("content", result);
424-
choice.insert("message", message);
425-
if (MySettings::globalInstance()->localDocsShowReferences()) {
426-
QJsonArray references;
427-
for (const auto &ref : infos)
428-
references.append(resultToJson(ref));
429-
choice.insert("references", references);
430-
}
431-
choices.append(choice);
432-
}
433-
} else {
434-
int index = 0;
435-
for (const auto &r : responses) {
436-
QString result = r.first;
437-
QList<ResultInfo> infos = r.second;
438-
QJsonObject choice;
439-
choice.insert("text", result);
440-
choice.insert("index", index++);
441-
choice.insert("logprobs", QJsonValue::Null); // We don't support
442-
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
443-
if (MySettings::globalInstance()->localDocsShowReferences()) {
444-
QJsonArray references;
445-
for (const auto &ref : infos)
446-
references.append(resultToJson(ref));
447-
choice.insert("references", references);
448-
}
449-
choices.append(choice);
323+
choice.insert("index", i);
324+
choice.insert("delta", delta);
325+
326+
QJsonObject responseChunk;
327+
responseChunk.insert("id", randomId);
328+
responseChunk.insert("object", "chat.completion.chunk");
329+
responseChunk.insert("created", QDateTime::currentSecsSinceEpoch());
330+
responseChunk.insert("model", modelInfo.name());
331+
responseChunk.insert("choices", QJsonArray{choice});
332+
333+
stream << "data: " << QJsonDocument(responseChunk).toJson(QJsonDocument::Compact) << "\n\n";
334+
stream.flush();
450335
}
336+
337+
if (i != n - 1)
338+
resetResponse();
451339
}
452340

453-
responseObject.insert("choices", choices);
341+
// Final empty delta to signify the end of the stream
342+
QJsonObject delta;
343+
delta.insert("content", QJsonValue::Null);
454344

455-
QJsonObject usage;
456-
usage.insert("prompt_tokens", int(promptTokens));
457-
usage.insert("completion_tokens", int(responseTokens));
458-
usage.insert("total_tokens", int(promptTokens + responseTokens));
459-
responseObject.insert("usage", usage);
345+
QJsonObject choice;
346+
choice.insert("index", 0);
347+
choice.insert("delta", delta);
348+
choice.insert("finish_reason", "stop");
460349

461-
#if defined(DEBUG)
462-
QJsonDocument newDoc(responseObject);
463-
printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
464-
fflush(stdout);
465-
#endif
350+
QJsonObject finalChunk;
351+
finalChunk.insert("id", randomId);
352+
finalChunk.insert("object", "chat.completion.chunk");
353+
finalChunk.insert("created", QDateTime::currentSecsSinceEpoch());
354+
finalChunk.insert("model", modelInfo.name());
355+
finalChunk.insert("choices", QJsonArray{choice});
356+
357+
stream << "data: " << QJsonDocument(finalChunk).toJson(QJsonDocument::Compact) << "\n\n";
358+
stream << "data: [DONE]\n\n";
359+
stream.flush();
360+
361+
// Log the entire response data
362+
qDebug() << "Full SSE Response:\n" << responseData;
363+
364+
// Create the response
365+
QHttpServerResponse response(responseData, QHttpServerResponder::StatusCode::Ok);
366+
response.setHeader("Content-Type", "text/event-stream");
367+
response.setHeader("Cache-Control", "no-cache");
368+
response.setHeader("Connection", "keep-alive");
466369

467-
return QHttpServerResponse(responseObject);
370+
return response;
468371
}

0 commit comments

Comments
 (0)