Skip to content

[ML] Append all data to Chat Completion buffer #127658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127658.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127658
summary: Append all data to Chat Completion buffer
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.DequeUtils;

import java.io.IOException;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Flow;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
Expand All @@ -32,9 +35,7 @@
/**
* Chat Completion results that only contain a Flow.Publisher.
*/
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
implements
InferenceServiceResults {
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {

public static final String NAME = "chat_completion_chunk";
public static final String MODEL_FIELD = "model";
Expand All @@ -57,6 +58,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Inf
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TYPE_FIELD = "type";

/**
* OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
* So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time.
*/
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
AtomicBoolean onComplete = new AtomicBoolean();
this.publisher = downstream -> {
publisher.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
var nextItem = buffer.poll();
if (nextItem != null) {
downstream.onNext(new Results(DequeUtils.of(nextItem)));
} else if (onComplete.get()) {
downstream.onComplete();
} else {
subscription.request(n);
}
}

@Override
public void cancel() {
subscription.cancel();
}
});
}

@Override
public void onNext(Results item) {
var chunks = item.chunks();
var firstItem = chunks.poll();
chunks.forEach(buffer::offer);
downstream.onNext(new Results(DequeUtils.of(firstItem)));
}

@Override
public void onError(Throwable throwable) {
downstream.onError(throwable);
}

@Override
public void onComplete() {
// only complete if the buffer is empty, so that the client has a chance to drain the buffer
if (onComplete.compareAndSet(false, true)) {
if (buffer.isEmpty()) {
downstream.onComplete();
}
}
}
});
};
}

@Override
public boolean isStreaming() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,17 @@
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
StreamingUnifiedChatCompletionResults.Results> {

Expand Down Expand Up @@ -198,6 +207,66 @@ public void testToolCallToXContentChunked() throws IOException {
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
}

public void testBufferedPublishing() {
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
results.offer(randomChatCompletionChunk());
results.offer(randomChatCompletionChunk());
var completed = new AtomicBoolean();
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
if (completed.compareAndSet(false, true)) {
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
} else {
downstream.onComplete();
}
}

@Override
public void cancel() {
fail("Cancel should never be called.");
}
});
});

AtomicInteger counter = new AtomicInteger(0);
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
if (upstream.compareAndSet(null, subscription) == false) {
fail("Upstream already set?!");
}
subscription.request(1);
}

@Override
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
assertNotNull(item);
counter.incrementAndGet();
var sub = upstream.get();
if (sub != null) {
sub.request(1);
} else {
fail("Upstream not yet set?!");
}
}

@Override
public void onError(Throwable throwable) {
fail(throwable);
}

@Override
public void onComplete() {}
}
);
streamingResults.publisher().subscribe(subscriber);
verify(subscriber, times(2)).onNext(any());
}

@Override
protected Writeable.Reader<StreamingUnifiedChatCompletionResults.Results> instanceReader() {
return StreamingUnifiedChatCompletionResults.Results::new;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.DequeUtils;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
Expand Down Expand Up @@ -256,37 +257,24 @@ public void cancel() {}
"object": "chat.completion.chunk"
}
*/
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
return new InferenceServiceResults.Result() {
@Override
public String getWriteableName() {
return "test_unifiedCompletionChunk";
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(delta);
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
return ChunkedToXContentHelper.chunk(
(b, p) -> b.startObject()
.field("id", "id")
.startArray("choices")
.startObject()
.startObject("delta")
.field("content", delta)
.endObject()
.field("index", 0)
.endObject()
.endArray()
.field("model", "gpt-4o-2024-08-06")
.field("object", "chat.completion.chunk")
.endObject()
);
}
};
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
return new StreamingUnifiedChatCompletionResults.Results(
DequeUtils.of(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
"id",
List.of(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
null,
0
)
),
"gpt-4o-2024-08-06",
"chat.completion.chunk",
null
)
)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.BiFunction;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
Expand Down Expand Up @@ -60,21 +59,11 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
public static final String TOTAL_TOKENS_FIELD = "total_tokens";

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();

public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
this.errorParser = errorParser;
}

@Override
protected void upstreamRequest(long n) {
if (buffer.isEmpty()) {
super.upstreamRequest(n);
} else {
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
}
}

@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
Expand All @@ -96,15 +85,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

if (results.isEmpty()) {
upstream().request(1);
} else if (results.size() == 1) {
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
} else {
// results > 1, but openai spec only wants 1 chunk per SSE event
var firstItem = singleItem(results.poll());
while (results.isEmpty() == false) {
buffer.offer(results.poll());
}
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
}
}

Expand Down Expand Up @@ -297,12 +279,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
}
}
}

private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
) {
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
deque.offer(result);
return deque;
}
}
Loading