Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support llm streaming on top of arrow flight #3645

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,8 @@ public String encrypt(String credential, String tenantId) {
return encryptor.encrypt(credential, tenantId);
}

/**
 * Injects the {@link StreamManager} supplier used to register streaming model responses.
 * A {@link Supplier} is used so the (node-scoped) manager can be resolved lazily at call time.
 *
 * @param streamManager supplier of the Arrow Flight stream manager; stored as-is, not invoked here
 */
public void setStreamManager(Supplier<StreamManager> streamManager) {
this.streamManager = streamManager;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import org.apache.logging.log4j.Logger;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.action.ActionListener;
Expand All @@ -25,15 +32,24 @@
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.arrow.RemoteModelStreamProducer;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.transport.client.Client;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.internal.http2.StreamResetException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
Expand Down Expand Up @@ -63,13 +79,30 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {

private SdkAsyncHttpClient httpClient;

@Setter
@Getter
private Supplier<StreamManager> streamManager;
private OkHttpClient okHttpClient;

/**
 * Builds the executor for an AWS connector: the async SDK client (timeouts/pool size taken
 * from the connector's client config) plus an OkHttp client used for SSE streaming calls.
 *
 * @param connector the connector definition; must be an {@link AwsConnector}
 */
public AwsConnectorExecutor(Connector connector) {
    super.initialize(connector);
    this.connector = (AwsConnector) connector;
    Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
    Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
    Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
    this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
    // Building the client throws no checked exception, so a plain PrivilegedAction suffices —
    // no PrivilegedActionException wrapping/unwrapping needed. The privileged block is kept
    // because OkHttp initialization touches resources restricted under the security manager.
    // NOTE(review): timeouts are hard-coded (10s connect / 1m read) — consider reusing the
    // connector client config like the SDK client above.
    this.okHttpClient = AccessController
        .doPrivileged(
            (PrivilegedAction<OkHttpClient>) () -> new OkHttpClient.Builder()
                .connectTimeout(10, TimeUnit.SECONDS)
                .readTimeout(1, TimeUnit.MINUTES)
                .retryOnConnectionFailure(true)
                .build()
        );
}

@Override
Expand Down Expand Up @@ -126,6 +159,41 @@ public void invokeRemoteService(
}
}

/**
 * Starts a streaming invocation of the remote model over SSE.
 * <p>
 * Registers a {@link RemoteModelStreamProducer} with the stream manager, immediately returns the
 * resulting stream ticket to the caller via {@code actionListener}, then opens an OkHttp SSE
 * connection whose events are pumped into the producer's queue by {@link AwsEventSourceListener}.
 *
 * @param action           connector action name used to resolve the endpoint
 * @param mlInput          original ML input (unused here; payload is already rendered)
 * @param parameters       request parameters used for endpoint substitution
 * @param payload          rendered request body
 * @param executionContext execution context (unused here)
 * @param actionListener   receives the tensors carrying the stream ticket, or the failure
 */
@Override
public void invokeRemoteServiceStream(
    String action,
    MLInput mlInput,
    Map<String, String> parameters,
    String payload,
    ExecutionContext executionContext,
    ActionListener<Tuple<Integer, ModelTensors>> actionListener
) {
    try {
        RemoteModelStreamProducer streamProducer = new RemoteModelStreamProducer();
        StreamTicket streamTicket = streamManager.get().registerStream(streamProducer, null);
        getLogger().debug("Registered stream ticket: {}", streamTicket);
        // Respond with the ticket first so the client can start reading the Arrow stream
        // while the SSE connection is being established.
        List<ModelTensor> modelTensors = new ArrayList<>();
        modelTensors.add(ModelTensor.builder().name("response").dataAsMap(Map.of("stream_ticket", streamTicket)).build());
        actionListener.onResponse(new Tuple<>(0, new ModelTensors(modelTensors)));
        EventSourceListener listener = new AwsEventSourceListener(getLogger(), true, streamProducer);
        Request request = ConnectorUtils.buildOKHttpRequestPOST(action, connector, parameters, payload);
        // Intentionally not logging the request body: the payload may contain credentials or user data.
        // newEventSource throws no checked exception, so PrivilegedAction is sufficient.
        AccessController
            .doPrivileged((PrivilegedAction<EventSource>) () -> EventSources.createFactory(okHttpClient).newEventSource(request, listener));
    } catch (RuntimeException exception) {
        log.error("[stream]Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
        actionListener.onFailure(exception);
    } catch (Throwable e) {
        // Route every failure through the listener instead of letting it escape the transport thread.
        log.error("[stream]Failed to execute {} in aws connector", action, e);
        actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));
    }
}

private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
String accessKey = connector.getAccessKey();
String secretKey = connector.getSecretKey();
Expand All @@ -135,4 +203,85 @@ private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {

return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region);
}

/**
 * OkHttp SSE listener that forwards Anthropic-style content-block events from the remote model
 * into a {@link RemoteModelStreamProducer}, which in turn serves them over Arrow Flight.
 * Not thread-safe; OkHttp invokes callbacks for a single event source sequentially.
 */
public final class AwsEventSourceListener extends EventSourceListener {
    private final Logger logger;
    private final boolean debug;
    private int publishErrorCount;
    private final RemoteModelStreamProducer streamProducer;

    public AwsEventSourceListener(final Logger logger, final boolean debug, RemoteModelStreamProducer streamProducer) {
        this.logger = logger;
        this.debug = debug;
        this.publishErrorCount = 0;
        this.streamProducer = streamProducer;
    }

    /***
     * Callback when the SSE endpoint connection is made.
     * @param eventSource the event source
     * @param response the response
     */
    @Override
    public void onOpen(EventSource eventSource, Response response) {
        logger.info("Connected to SSE Endpoint.");
    }

    /***
     * For each event received from the SSE endpoint.
     * Content deltas are queued on the producer; a stop event marks the end of the stream.
     * @param eventSource The event source
     * @param id The id of the event
     * @param type The type of the event which is used to filter
     * @param data The event data
     */
    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        log.info("The event id is {} and the type is {}.", id, type);
        // SSE events may arrive without a type; switch-on-null would throw NPE.
        if (type == null) {
            log.info("[No action] Got an event without a type.");
            return;
        }
        switch (type) {
            case "content_block_start":
                log.info("Got the content start.");
                break;
            case "content_block_delta":
                if (debug) {
                    log.info("Content Delta: {}", data);
                }
                streamProducer.getQueue().offer(data);
                break;
            case "content_block_stop":
                log.info("Got the content stop.");
                // Signals the producer loop to finish once the queue drains.
                streamProducer.getIsStop().set(true);
                break;
            default:
                log.info("[No action] Got the content type: {}.", type);
                break;
        }
    }

    /***
     * When the connection is closed we receive this event which is currently only logged.
     * @param eventSource The event source
     */
    @Override
    public void onClosed(EventSource eventSource) {
        logger.info("Closed");
    }

    /***
     * If there is any failure we log the error and the stack trace.
     * A benign HTTP/2 stream reset (NO_ERROR) is tolerated; anything else marks the producer
     * as failed so the Arrow stream surfaces the error to the consumer.
     * @param eventSource The event source
     * @param t The error object
     * @param response The response
     */
    @Override
    public void onFailure(EventSource eventSource, Throwable t, Response response) {
        if (t != null) {
            logger.error("Error: {}", t.getMessage(), t);
            // getMessage() can be null — guard before contains() to avoid a secondary NPE.
            String message = t.getMessage();
            if (t instanceof StreamResetException && message != null && message.contains("NO_ERROR")) {
                // TODO: reconnect
            } else {
                streamProducer.setProduceError(true);
                throw new MLException("SSE failure 2.", t);
            }
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import com.jayway.jsonpath.JsonPath;

import lombok.extern.log4j.Log4j2;
import okhttp3.MediaType;
import okhttp3.Request;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
Expand Down Expand Up @@ -330,6 +332,28 @@ public static SdkHttpFullRequest buildSdkRequest(
return builder.build();
}

/**
 * Builds an OkHttp POST request for a streaming (SSE) call to the connector's action endpoint.
 * The body is sent as {@code application/json; charset=utf-8} and the request asks for
 * {@code text/event-stream} with encoding and caching disabled.
 *
 * @param action     connector action name used to resolve the endpoint URL
 * @param connector  connector providing the endpoint template
 * @param parameters parameters substituted into the endpoint URL
 * @param payload    JSON request body; must not be null
 * @return the ready-to-send OkHttp request
 * @throws IllegalArgumentException if {@code payload} is null
 */
public static Request buildOKHttpRequestPOST(String action, Connector connector, Map<String, String> parameters, String payload) {
    // Guard clause: a streaming POST without a body is always a caller error.
    if (payload == null) {
        throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
    }
    okhttp3.RequestBody requestBody = okhttp3.RequestBody.create(payload, MediaType.parse("application/json; charset=utf-8"));

    String endpoint = connector.getActionEndpoint(action, parameters);
    return new Request.Builder()
        .url(endpoint)
        // Disable compression so SSE chunks are delivered as they arrive.
        .header("Accept-Encoding", "")
        .header("Accept", "text/event-stream")
        .header("Cache-Control", "no-cache")
        .post(requestBody)
        .build();
}

public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) {
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name());
String predictEndpoint = batchPredictAction.get().getUrl();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ public void invokeRemoteService(
}
}

/**
 * Streaming is not supported by this connector executor.
 *
 * @throws UnsupportedOperationException always; this executor only supports non-streaming calls
 */
@Override
public void invokeRemoteServiceStream(
    String action,
    MLInput mlInput,
    Map<String, String> parameters,
    String payload,
    ExecutionContext executionContext,
    ActionListener<Tuple<Integer, ModelTensors>> actionListener
) {
    // UnsupportedOperationException is the precise type for an unimplemented operation and is
    // still a RuntimeException, so existing callers catching RuntimeException are unaffected.
    throw new UnsupportedOperationException("Unsupported stream remote service.");
}

private void validateHttpClientParameters(String action, Map<String, String> parameters) throws Exception {
String endpoint = connector.getActionEndpoint(action, parameters);
URL url = new URL(endpoint);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.opensearch.ml.engine.arrow;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.common.unit.TimeValue;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;

/**
 * Arrow Flight {@link StreamProducer} that serves LLM streaming responses.
 * SSE callbacks enqueue event payloads onto {@link #queue}; the batched job drains the queue,
 * writing one event string per flushed batch into a single-column ("event") VarChar root,
 * until {@link #isStop} is set and the queue is empty, or {@link #produceError} aborts the stream.
 */
@Log4j2
public class RemoteModelStreamProducer implements StreamProducer<VectorSchemaRoot, BufferAllocator> {
    volatile boolean isClosed = false;
    private final CountDownLatch closeLatch = new CountDownLatch(1);
    private final TimeValue deadline = TimeValue.timeValueSeconds(5);
    // Set by the SSE listener on unrecoverable failure; makes the producer job fail fast.
    private volatile boolean produceError = false;

    // Flipped by the SSE listener when the remote model signals end-of-stream.
    @Getter
    private final AtomicBoolean isStop = new AtomicBoolean(false);
    // Hand-off buffer between the SSE callback thread and the Arrow producer job.
    @Getter
    private final ConcurrentLinkedQueue<String> queue = new ConcurrentLinkedQueue<>();

    public void setProduceError(boolean produceError) {
        this.produceError = produceError;
    }

    public RemoteModelStreamProducer() {}

    VectorSchemaRoot root;

    @Override
    public VectorSchemaRoot createRoot(BufferAllocator allocator) {
        VarCharVector eventVector = new VarCharVector("event", allocator);
        FieldVector[] vectors = new FieldVector[] { eventVector };
        root = new VectorSchemaRoot(Arrays.asList(vectors));
        return root;
    }

    @Override
    public BatchedJob<VectorSchemaRoot> createJob(BufferAllocator allocator) {
        return new BatchedJob<>() {
            @Override
            public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
                VarCharVector eventVector = (VarCharVector) root.getVector("event");
                root.setRowCount(1);
                while (true) {
                    if (produceError) {
                        throw new RuntimeException("Server error while producing batch");
                    }
                    // Stream is complete only when the remote signaled stop AND all
                    // buffered events have been flushed to the consumer.
                    if (queue.isEmpty() && isStop.get()) {
                        log.info("The end of streaming response.");
                        return;
                    }
                    String event = queue.poll();
                    if (event != null) {
                        eventVector.setSafe(0, event.getBytes(StandardCharsets.UTF_8));
                        // Block until the consumer has read this batch before reusing the vector.
                        flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000));
                        eventVector.clear();
                        root.setRowCount(1);
                    } else {
                        // Queue is empty but the stream is still open: back off briefly.
                        try {
                            Thread.sleep(500);
                        } catch (InterruptedException e) {
                            // Restore the interrupt flag and preserve the cause instead of swallowing it.
                            Thread.currentThread().interrupt();
                            throw new RuntimeException("sleep failure", e);
                        }
                    }
                }
            }

            @Override
            public void onCancel() {
                // Root may be null if the consumer cancels before createRoot ran.
                if (root != null) {
                    root.close();
                }
                isClosed = true;
            }

            @Override
            public boolean isCancelled() {
                return isClosed;
            }
        };
    }

    @Override
    public TimeValue getJobDeadline() {
        return deadline;
    }

    @Override
    public int estimatedRowCount() {
        return 100;
    }

    @Override
    public String getAction() {
        return "";
    }

    @Override
    public void close() {
        // Guard against close() before createRoot() — root would be null.
        if (root != null) {
            root.close();
        }
        closeLatch.countDown();
        isClosed = true;
    }

    /**
     * Waits for {@link #close()} to be called.
     *
     * @param timeout maximum time to wait
     * @param unit    unit of {@code timeout}
     * @return true if the producer closed within the timeout
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public boolean waitForClose(long timeout, TimeUnit unit) throws InterruptedException {
        return closeLatch.await(timeout, unit);
    }
}
Loading
Loading