1919#include < QtLogging>
2020
2121#include < iostream>
22- #include < string>
23- #include < type_traits>
2422#include < utility>
2523
2624using namespace Qt ::Literals::StringLiterals;
@@ -207,26 +205,29 @@ void Server::start()
207205
208206QHttpServerResponse Server::handleCompletionRequest (const QHttpServerRequest &request, bool isChat)
209207{
210- // We've been asked to do a completion...
208+ // Parse JSON request
211209 QJsonParseError err;
212210 const QJsonDocument document = QJsonDocument::fromJson (request.body (), &err);
213211 if (err.error || !document.isObject ()) {
214- std::cerr << " ERROR: invalid json in completions body" << std::endl;
212+ std::cerr << " ERROR: invalid JSON in completions body" << std::endl;
215213 return QHttpServerResponse (QHttpServerResponder::StatusCode::NoContent);
216214 }
215+
217216#if defined(DEBUG)
218217 printf (" /v1/completions %s\n " , qPrintable (document.toJson (QJsonDocument::Indented)));
219218 fflush (stdout);
220219#endif
220+
221221 const QJsonObject body = document.object ();
222- if (!body.contains (" model" )) { // required
223- std::cerr << " ERROR: completions contains no model" << std::endl;
222+ if (!body.contains (" model" )) {
223+ std::cerr << " ERROR: completions contain no model" << std::endl;
224224 return QHttpServerResponse (QHttpServerResponder::StatusCode::NoContent);
225225 }
226+
226227 QJsonArray messages;
227228 if (isChat) {
228229 if (!body.contains (" messages" )) {
229- std::cerr << " ERROR: chat completions contains no messages" << std::endl;
230+ std::cerr << " ERROR: chat completions contain no messages" << std::endl;
230231 return QHttpServerResponse (QHttpServerResponder::StatusCode::NoContent);
231232 }
232233 messages = body[" messages" ].toArray ();
@@ -236,16 +237,12 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
236237 ModelInfo modelInfo = ModelList::globalInstance ()->defaultModelInfo ();
237238 const QList<ModelInfo> modelList = ModelList::globalInstance ()->selectableModelList ();
238239 for (const ModelInfo &info : modelList) {
239- Q_ASSERT (info.installed );
240- if (!info.installed )
241- continue ;
242- if (modelRequested == info.name () || modelRequested == info.filename ()) {
240+ if (info.installed && (modelRequested == info.name () || modelRequested == info.filename ())) {
243241 modelInfo = info;
244242 break ;
245243 }
246244 }
247245
248- // We only support one prompt for now
249246 QList<QString> prompts;
250247 if (body.contains (" prompt" )) {
251248 QJsonValue promptValue = body[" prompt" ];
@@ -256,102 +253,32 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
256253 for (const QJsonValue &v : array)
257254 prompts.append (v.toString ());
258255 }
259- } else
256+ } else {
260257 prompts.append (" " );
261-
262- int max_tokens = 16 ;
263- if (body.contains (" max_tokens" ))
264- max_tokens = body[" max_tokens" ].toInt ();
265-
266- float temperature = 1 .f ;
267- if (body.contains (" temperature" ))
268- temperature = body[" temperature" ].toDouble ();
269-
270- float top_p = 1 .f ;
271- if (body.contains (" top_p" ))
272- top_p = body[" top_p" ].toDouble ();
273-
274- float min_p = 0 .f ;
275- if (body.contains (" min_p" ))
276- min_p = body[" min_p" ].toDouble ();
277-
278- int n = 1 ;
279- if (body.contains (" n" ))
280- n = body[" n" ].toInt ();
281-
282- int logprobs = -1 ; // supposed to be null by default??
283- if (body.contains (" logprobs" ))
284- logprobs = body[" logprobs" ].toInt ();
285-
286- bool echo = false ;
287- if (body.contains (" echo" ))
288- echo = body[" echo" ].toBool ();
289-
290- // We currently don't support any of the following...
291- #if 0
292- // FIXME: Need configurable reverse prompts
293- QList<QString> stop;
294- if (body.contains("stop")) {
295- QJsonValue stopValue = body["stop"];
296- if (stopValue.isString())
297- stop.append(stopValue.toString());
298- else {
299- QJsonArray array = stopValue.toArray();
300- for (QJsonValue v : array)
301- stop.append(v.toString());
302- }
303258 }
304259
305- // FIXME: QHttpServer doesn't support server-sent events
306- bool stream = false;
307- if (body.contains("stream"))
308- stream = body["stream"].toBool();
309-
310- // FIXME: What does this do?
311- QString suffix;
312- if (body.contains("suffix"))
313- suffix = body["suffix"].toString();
314-
315- // FIXME: We don't support
316- float presence_penalty = 0.f;
317- if (body.contains("presence_penalty"))
318- top_p = body["presence_penalty"].toDouble();
319-
320- // FIXME: We don't support
321- float frequency_penalty = 0.f;
322- if (body.contains("frequency_penalty"))
323- top_p = body["frequency_penalty"].toDouble();
324-
325- // FIXME: We don't support
326- int best_of = 1;
327- if (body.contains("best_of"))
328- logprobs = body["best_of"].toInt();
329-
330- // FIXME: We don't need
331- QString user;
332- if (body.contains("user"))
333- suffix = body["user"].toString();
334- #endif
260+ int max_tokens = body.value (" max_tokens" ).toInt (16 );
261+ float temperature = body.value (" temperature" ).toDouble (1.0 );
262+ float top_p = body.value (" top_p" ).toDouble (1.0 );
263+ float min_p = body.value (" min_p" ).toDouble (0.0 );
264+ int n = body.value (" n" ).toInt (1 );
265+ bool echo = body.value (" echo" ).toBool (false );
335266
336267 QString actualPrompt = prompts.first ();
337268
338- // if we're a chat completion we have messages which means we need to prepend these to the prompt
339269 if (!messages.isEmpty ()) {
340270 QList<QString> chats;
341- for (int i = 0 ; i < messages.count (); ++i) {
342- QJsonValue v = messages.at (i);
343- QString content = v.toObject ()[" content" ].toString ();
271+ for (int i = 0 ; i < messages.count (); ++i) {
272+ QString content = messages.at (i).toObject ()[" content" ].toString ();
344273 if (!content.endsWith (" \n " ) && i < messages.count () - 1 )
345274 content += " \n " ;
346275 chats.append (content);
347276 }
348277 actualPrompt.prepend (chats.join (" \n " ));
349278 }
350279
351- // adds prompt/response items to GUI
352280 emit requestServerNewPromptResponsePair (actualPrompt); // blocks
353281
354- // load the new model if necessary
355282 setShouldBeLoaded (true );
356283
357284 if (modelInfo.filename ().isEmpty ()) {
@@ -362,107 +289,83 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
362289 return QHttpServerResponse (QHttpServerResponder::StatusCode::InternalServerError);
363290 }
364291
365- // don't remember any context
366292 resetContext ();
367293
368- const QString promptTemplate = modelInfo.promptTemplate ();
369- const float top_k = modelInfo.topK ();
370- const int n_batch = modelInfo.promptBatchSize ();
371- const float repeat_penalty = modelInfo.repeatPenalty ();
372- const int repeat_last_n = modelInfo.repeatPenaltyTokens ();
294+ QByteArray responseData;
295+ QTextStream stream (&responseData, QIODevice::WriteOnly);
296+
297+ QString randomId = " chatcmpl-" + QUuid::createUuid ().toString (QUuid::WithoutBraces).replace (" -" , " " );
373298
374- int promptTokens = 0 ;
375- int responseTokens = 0 ;
376- QList<QPair<QString, QList<ResultInfo>>> responses;
377299 for (int i = 0 ; i < n; ++i) {
378- if (!promptInternal (
379- m_collections,
380- actualPrompt,
381- promptTemplate,
382- max_tokens /* n_predict*/ ,
383- top_k,
384- top_p,
385- min_p,
386- temperature,
387- n_batch,
388- repeat_penalty,
389- repeat_last_n)) {
300+ if (!promptInternal (m_collections,
301+ actualPrompt,
302+ modelInfo.promptTemplate (),
303+ max_tokens /* n_predict*/ ,
304+ modelInfo.topK (),
305+ top_p,
306+ min_p,
307+ temperature,
308+ modelInfo.promptBatchSize (),
309+ modelInfo.repeatPenalty (),
310+ modelInfo.repeatPenaltyTokens ())) {
390311
391312 std::cerr << " ERROR: couldn't prompt model " << modelInfo.name ().toStdString () << std::endl;
392313 return QHttpServerResponse (QHttpServerResponder::StatusCode::InternalServerError);
393314 }
394- QString echoedPrompt = actualPrompt;
395- if (!echoedPrompt.endsWith (" \n " ))
396- echoedPrompt += " \n " ;
397- responses.append (qMakePair ((echo ? u" %1\n " _s.arg (actualPrompt) : QString ()) + response (), m_databaseResults));
398- if (!promptTokens)
399- promptTokens += m_promptTokens;
400- responseTokens += m_promptResponseTokens - m_promptTokens;
401- if (i != n - 1 )
402- resetResponse ();
403- }
404315
405- QJsonObject responseObject;
406- responseObject.insert (" id" , " foobarbaz" );
407- responseObject.insert (" object" , " text_completion" );
408- responseObject.insert (" created" , QDateTime::currentSecsSinceEpoch ());
409- responseObject.insert (" model" , modelInfo.name ());
316+ QString result = (echo ? u" %1\n " _s.arg (actualPrompt) : QString ()) + response ();
410317
411- QJsonArray choices;
318+ for (const QString &token : result.split (' ' )) {
319+ QJsonObject delta;
320+ delta.insert (" content" , token + " " );
412321
413- if (isChat) {
414- int index = 0 ;
415- for (const auto &r : responses) {
416- QString result = r.first ;
417- QList<ResultInfo> infos = r.second ;
418322 QJsonObject choice;
419- choice.insert (" index" , index++);
420- choice.insert (" finish_reason" , responseTokens == max_tokens ? " length" : " stop" );
421- QJsonObject message;
422- message.insert (" role" , " assistant" );
423- message.insert (" content" , result);
424- choice.insert (" message" , message);
425- if (MySettings::globalInstance ()->localDocsShowReferences ()) {
426- QJsonArray references;
427- for (const auto &ref : infos)
428- references.append (resultToJson (ref));
429- choice.insert (" references" , references);
430- }
431- choices.append (choice);
432- }
433- } else {
434- int index = 0 ;
435- for (const auto &r : responses) {
436- QString result = r.first ;
437- QList<ResultInfo> infos = r.second ;
438- QJsonObject choice;
439- choice.insert (" text" , result);
440- choice.insert (" index" , index++);
441- choice.insert (" logprobs" , QJsonValue::Null); // We don't support
442- choice.insert (" finish_reason" , responseTokens == max_tokens ? " length" : " stop" );
443- if (MySettings::globalInstance ()->localDocsShowReferences ()) {
444- QJsonArray references;
445- for (const auto &ref : infos)
446- references.append (resultToJson (ref));
447- choice.insert (" references" , references);
448- }
449- choices.append (choice);
323+ choice.insert (" index" , i);
324+ choice.insert (" delta" , delta);
325+
326+ QJsonObject responseChunk;
327+ responseChunk.insert (" id" , randomId);
328+ responseChunk.insert (" object" , " chat.completion.chunk" );
329+ responseChunk.insert (" created" , QDateTime::currentSecsSinceEpoch ());
330+ responseChunk.insert (" model" , modelInfo.name ());
331+ responseChunk.insert (" choices" , QJsonArray{choice});
332+
333+ stream << " data: " << QJsonDocument (responseChunk).toJson (QJsonDocument::Compact) << " \n\n " ;
334+ stream.flush ();
450335 }
336+
337+ if (i != n - 1 )
338+ resetResponse ();
451339 }
452340
453- responseObject.insert (" choices" , choices);
341+ // Final empty delta to signify the end of the stream
342+ QJsonObject delta;
343+ delta.insert (" content" , QJsonValue::Null);
454344
455- QJsonObject usage;
456- usage.insert (" prompt_tokens" , int (promptTokens));
457- usage.insert (" completion_tokens" , int (responseTokens));
458- usage.insert (" total_tokens" , int (promptTokens + responseTokens));
459- responseObject.insert (" usage" , usage);
345+ QJsonObject choice;
346+ choice.insert (" index" , 0 );
347+ choice.insert (" delta" , delta);
348+ choice.insert (" finish_reason" , " stop" );
460349
461- #if defined(DEBUG)
462- QJsonDocument newDoc (responseObject);
463- printf (" /v1/completions %s\n " , qPrintable (newDoc.toJson (QJsonDocument::Indented)));
464- fflush (stdout);
465- #endif
350+ QJsonObject finalChunk;
351+ finalChunk.insert (" id" , randomId);
352+ finalChunk.insert (" object" , " chat.completion.chunk" );
353+ finalChunk.insert (" created" , QDateTime::currentSecsSinceEpoch ());
354+ finalChunk.insert (" model" , modelInfo.name ());
355+ finalChunk.insert (" choices" , QJsonArray{choice});
356+
357+ stream << " data: " << QJsonDocument (finalChunk).toJson (QJsonDocument::Compact) << " \n\n " ;
358+ stream << " data: [DONE]\n\n " ;
359+ stream.flush ();
360+
361+ // Log the entire response data
362+ qDebug () << " Full SSE Response:\n " << responseData;
363+
364+ // Create the response
365+ QHttpServerResponse response (responseData, QHttpServerResponder::StatusCode::Ok);
366+ response.setHeader (" Content-Type" , " text/event-stream" );
367+ response.setHeader (" Cache-Control" , " no-cache" );
368+ response.setHeader (" Connection" , " keep-alive" );
466369
467- return QHttpServerResponse (responseObject) ;
370+ return response ;
468371}
0 commit comments