Skip to content

Commit 2f80171

Browse files
authored
Implement genai events for bedrock (streaming) (#13507)
1 parent 06bd699 commit 2f80171

File tree

6 files changed

+1273
-45
lines changed

6 files changed

+1273
-45
lines changed

instrumentation/aws-sdk/aws-sdk-2.2/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v2_2/AwsSdkTelemetry.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,6 @@ public SqsAsyncClient wrap(SqsAsyncClient sqsClient) {
152152
/**
 * Wraps the given Bedrock runtime async client by delegating to {@code BedrockRuntimeImpl.wrap}.
 *
 * <p>After this change the event logger and the message-content capture flag held by this
 * telemetry object are forwarded to the wrapper as well.
 *
 * @param bedrockClient the client to wrap
 * @return the client returned by {@code BedrockRuntimeImpl.wrap}
 */
@NoMuzzle
153153
public BedrockRuntimeAsyncClient wrapBedrockRuntimeClient(
154154
BedrockRuntimeAsyncClient bedrockClient) {
155-
return BedrockRuntimeImpl.wrap(bedrockClient);
155+
return BedrockRuntimeImpl.wrap(bedrockClient, eventLogger, genAiCaptureMessageContent);
156156
}
157157
}

instrumentation/aws-sdk/aws-sdk-2.2/library/src/main/java/io/opentelemetry/instrumentation/awssdk/v2_2/internal/BedrockRuntimeImpl.java

+164-44
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,21 @@
2424
import java.util.HashMap;
2525
import java.util.List;
2626
import java.util.Map;
27+
import java.util.stream.Collectors;
2728
import javax.annotation.Nullable;
2829
import software.amazon.awssdk.core.SdkRequest;
2930
import software.amazon.awssdk.core.SdkResponse;
3031
import software.amazon.awssdk.core.async.SdkPublisher;
3132
import software.amazon.awssdk.core.document.Document;
3233
import software.amazon.awssdk.protocols.json.SdkJsonGenerator;
34+
import software.amazon.awssdk.protocols.jsoncore.JsonNode;
35+
import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
3336
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
3437
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
38+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
39+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
40+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
41+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
3542
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
3643
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
3744
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
@@ -41,11 +48,13 @@
4148
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
4249
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
4350
import software.amazon.awssdk.services.bedrockruntime.model.Message;
51+
import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
4452
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
4553
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
4654
import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
4755
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
4856
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
57+
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;
4958
import software.amazon.awssdk.thirdparty.jackson.core.JsonFactory;
5059

5160
/**
@@ -59,6 +68,8 @@ private BedrockRuntimeImpl() {}
5968
private static final AttributeKey<String> GEN_AI_SYSTEM = stringKey("gen_ai.system");
6069

6170
private static final JsonFactory JSON_FACTORY = new JsonFactory();
71+
private static final JsonNodeParser JSON_PARSER = JsonNode.parser();
72+
private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller();
6273

6374
static boolean isBedrockRuntimeRequest(SdkRequest request) {
6475
if (request instanceof ConverseRequest) {
@@ -202,35 +213,54 @@ static Long getUsageOutputTokens(Response response) {
202213
static void recordRequestEvents(
203214
Context otelContext, Logger eventLogger, SdkRequest request, boolean captureMessageContent) {
204215
if (request instanceof ConverseRequest) {
205-
for (Message message : ((ConverseRequest) request).messages()) {
206-
long numToolResults =
207-
message.content().stream().filter(block -> block.toolResult() != null).count();
208-
if (numToolResults > 0) {
209-
// Tool results are different from others, emitting multiple events for a single message,
210-
// so treat them separately.
211-
emitToolResultEvents(otelContext, eventLogger, message, captureMessageContent);
212-
if (numToolResults == message.content().size()) {
213-
continue;
214-
}
215-
// There are content blocks besides tool results in the same message. While models
216-
// generally don't expect such usage, the SDK allows it so go ahead and generate a normal
217-
// message too.
218-
}
219-
LogRecordBuilder event = newEvent(otelContext, eventLogger);
220-
switch (message.role()) {
221-
case ASSISTANT:
222-
event.setAttribute(EVENT_NAME, "gen_ai.assistant.message");
223-
break;
224-
case USER:
225-
event.setAttribute(EVENT_NAME, "gen_ai.user.message");
226-
break;
227-
default:
228-
// unknown role, shouldn't happen in practice
229-
continue;
216+
recordRequestMessageEvents(
217+
otelContext, eventLogger, ((ConverseRequest) request).messages(), captureMessageContent);
218+
}
219+
if (request instanceof ConverseStreamRequest) {
220+
recordRequestMessageEvents(
221+
otelContext,
222+
eventLogger,
223+
((ConverseStreamRequest) request).messages(),
224+
captureMessageContent);
225+
226+
// Good a time as any to store the context for a streaming request.
227+
TracingConverseStreamResponseHandler.fromContext(otelContext).setOtelContext(otelContext);
228+
}
229+
}
230+
231+
private static void recordRequestMessageEvents(
232+
Context otelContext,
233+
Logger eventLogger,
234+
List<Message> messages,
235+
boolean captureMessageContent) {
236+
for (Message message : messages) {
237+
long numToolResults =
238+
message.content().stream().filter(block -> block.toolResult() != null).count();
239+
if (numToolResults > 0) {
240+
// Tool results are different from others, emitting multiple events for a single message,
241+
// so treat them separately.
242+
emitToolResultEvents(otelContext, eventLogger, message, captureMessageContent);
243+
if (numToolResults == message.content().size()) {
244+
continue;
230245
}
231-
// Requests don't have index or stop reason.
232-
event.setBody(convertMessage(message, -1, null, captureMessageContent)).emit();
246+
// There are content blocks besides tool results in the same message. While models
247+
// generally don't expect such usage, the SDK allows it so go ahead and generate a normal
248+
// message too.
233249
}
250+
LogRecordBuilder event = newEvent(otelContext, eventLogger);
251+
switch (message.role()) {
252+
case ASSISTANT:
253+
event.setAttribute(EVENT_NAME, "gen_ai.assistant.message");
254+
break;
255+
case USER:
256+
event.setAttribute(EVENT_NAME, "gen_ai.user.message");
257+
break;
258+
default:
259+
// unknown role, shouldn't happen in practice
260+
continue;
261+
}
262+
// Requests don't have index or stop reason.
263+
event.setBody(convertMessage(message, -1, null, captureMessageContent)).emit();
234264
}
235265
}
236266

@@ -248,7 +278,7 @@ static void recordResponseEvents(
248278
convertMessage(
249279
converseResponse.output().message(),
250280
0,
251-
converseResponse.stopReason(),
281+
converseResponse.stopReasonAsString(),
252282
captureMessageContent))
253283
.emit();
254284
}
@@ -270,7 +300,8 @@ private static Double floatToDouble(Float value) {
270300
return Double.valueOf(value);
271301
}
272302

273-
public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClient) {
303+
public static BedrockRuntimeAsyncClient wrap(
304+
BedrockRuntimeAsyncClient asyncClient, Logger eventLogger, boolean captureMessageContent) {
274305
// proxy BedrockRuntimeAsyncClient so we can wrap the subscriber to converseStream to capture
275306
// events.
276307
return (BedrockRuntimeAsyncClient)
@@ -283,7 +314,9 @@ public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClie
283314
&& args[1] instanceof ConverseStreamResponseHandler) {
284315
TracingConverseStreamResponseHandler wrapped =
285316
new TracingConverseStreamResponseHandler(
286-
(ConverseStreamResponseHandler) args[1]);
317+
(ConverseStreamResponseHandler) args[1],
318+
eventLogger,
319+
captureMessageContent);
287320
args[1] = wrapped;
288321
try (Scope ignored = wrapped.makeCurrent()) {
289322
return invokeProxyMethod(method, asyncClient, args);
@@ -318,12 +351,29 @@ public static TracingConverseStreamResponseHandler fromContext(Context context)
318351
ContextKey.named("bedrock-runtime-converse-stream-response-handler");
319352

320353
private final ConverseStreamResponseHandler delegate;
354+
private final Logger eventLogger;
355+
private final boolean captureMessageContent;
356+
357+
private StringBuilder currentText;
358+
359+
// The response handler is created and stored into context before the span, so we need to
360+
// also pass the later context in for recording events. While subscribers are called from a
361+
// single thread, it is not clear if that is guaranteed to be the same as the execution
362+
// interceptor so we use volatile.
363+
private volatile Context otelContext;
364+
365+
private List<ToolUseBlock> tools;
366+
private ToolUseBlock.Builder currentTool;
367+
private StringBuilder currentToolArgs;
321368

322369
List<String> stopReasons;
323370
TokenUsage usage;
324371

325-
TracingConverseStreamResponseHandler(ConverseStreamResponseHandler delegate) {
372+
/**
 * Creates a handler that forwards to {@code delegate} and keeps the logger and content-capture
 * flag for later use when events are emitted.
 *
 * @param delegate the user's response handler that calls are forwarded to
 * @param eventLogger logger stored for emitting events
 * @param captureMessageContent whether message content may be captured
 */
TracingConverseStreamResponseHandler(
373+
ConverseStreamResponseHandler delegate, Logger eventLogger, boolean captureMessageContent) {
326374
this.delegate = delegate;
375+
this.eventLogger = eventLogger;
376+
this.captureMessageContent = captureMessageContent;
327377
}
328378

329379
@Override
@@ -336,19 +386,66 @@ public void onEventStream(SdkPublisher<ConverseStreamOutput> sdkPublisher) {
336386
delegate.onEventStream(
337387
sdkPublisher.map(
338388
event -> {
339-
if (event instanceof MessageStopEvent) {
340-
if (stopReasons == null) {
341-
stopReasons = new ArrayList<>();
342-
}
343-
stopReasons.add(((MessageStopEvent) event).stopReasonAsString());
344-
}
345-
if (event instanceof ConverseStreamMetadataEvent) {
346-
usage = ((ConverseStreamMetadataEvent) event).usage();
347-
}
389+
handleEvent(event);
348390
return event;
349391
}));
350392
}
351393

394+
private void handleEvent(ConverseStreamOutput event) {
395+
if (captureMessageContent && event instanceof MessageStartEvent) {
396+
if (currentText == null) {
397+
currentText = new StringBuilder();
398+
}
399+
currentText.setLength(0);
400+
}
401+
if (event instanceof ContentBlockStartEvent) {
402+
ToolUseBlockStart toolUse = ((ContentBlockStartEvent) event).start().toolUse();
403+
if (toolUse != null) {
404+
if (currentToolArgs == null) {
405+
currentToolArgs = new StringBuilder();
406+
}
407+
currentToolArgs.setLength(0);
408+
currentTool = ToolUseBlock.builder().name(toolUse.name()).toolUseId(toolUse.toolUseId());
409+
}
410+
}
411+
if (event instanceof ContentBlockDeltaEvent) {
412+
ContentBlockDelta delta = ((ContentBlockDeltaEvent) event).delta();
413+
if (captureMessageContent && delta.text() != null) {
414+
currentText.append(delta.text());
415+
}
416+
if (delta.toolUse() != null) {
417+
currentToolArgs.append(delta.toolUse().input());
418+
}
419+
}
420+
if (event instanceof ContentBlockStopEvent) {
421+
if (currentTool != null) {
422+
if (tools == null) {
423+
tools = new ArrayList<>();
424+
}
425+
if (currentToolArgs != null) {
426+
Document args = deserializeDocument(currentToolArgs.toString());
427+
currentTool.input(args);
428+
}
429+
tools.add(currentTool.build());
430+
currentTool = null;
431+
}
432+
}
433+
if (event instanceof MessageStopEvent) {
434+
if (stopReasons == null) {
435+
stopReasons = new ArrayList<>();
436+
}
437+
String stopReason = ((MessageStopEvent) event).stopReasonAsString();
438+
stopReasons.add(stopReason);
439+
newEvent(otelContext, eventLogger)
440+
.setAttribute(EVENT_NAME, "gen_ai.choice")
441+
.setBody(convertMessageData(currentText, tools, 0, stopReason, captureMessageContent))
442+
.emit();
443+
}
444+
if (event instanceof ConverseStreamMetadataEvent) {
445+
usage = ((ConverseStreamMetadataEvent) event).usage();
446+
}
447+
}
448+
352449
@Override
353450
public void exceptionOccurred(Throwable throwable) {
354451
delegate.exceptionOccurred(throwable);
@@ -363,6 +460,10 @@ public void complete() {
363460
public Context storeInContext(Context context) {
364461
return context.with(KEY, this);
365462
}
463+
464+
/**
 * Stores the context that is passed to {@code newEvent} when this handler emits events for the
 * streamed response.
 *
 * @param otelContext the context to use for events emitted by this handler
 */
void setOtelContext(Context otelContext) {
465+
this.otelContext = otelContext;
466+
}
366467
}
367468

368469
private static LogRecordBuilder newEvent(Context otelContext, Logger eventLogger) {
@@ -401,9 +502,9 @@ private static void emitToolResultEvents(
401502
}
402503

403504
private static Value<?> convertMessage(
404-
Message message, int index, @Nullable StopReason stopReason, boolean captureMessageContent) {
505+
Message message, int index, @Nullable String stopReason, boolean captureMessageContent) {
405506
StringBuilder text = null;
406-
List<Value<?>> toolCalls = null;
507+
List<ToolUseBlock> toolCalls = null;
407508
for (ContentBlock content : message.content()) {
408509
if (captureMessageContent && content.text() != null) {
409510
if (text == null) {
@@ -415,15 +516,29 @@ private static Value<?> convertMessage(
415516
if (toolCalls == null) {
416517
toolCalls = new ArrayList<>();
417518
}
418-
toolCalls.add(convertToolCall(content.toolUse(), captureMessageContent));
519+
toolCalls.add(content.toolUse());
419520
}
420521
}
522+
523+
return convertMessageData(text, toolCalls, index, stopReason, captureMessageContent);
524+
}
525+
526+
private static Value<?> convertMessageData(
527+
@Nullable StringBuilder text,
528+
List<ToolUseBlock> toolCalls,
529+
int index,
530+
@Nullable String stopReason,
531+
boolean captureMessageContent) {
421532
Map<String, Value<?>> body = new HashMap<>();
422533
if (text != null) {
423534
body.put("content", Value.of(text.toString()));
424535
}
425536
if (toolCalls != null) {
426-
body.put("toolCalls", Value.of(toolCalls));
537+
List<Value<?>> toolCallValues =
538+
toolCalls.stream()
539+
.map(tool -> convertToolCall(tool, captureMessageContent))
540+
.collect(Collectors.toList());
541+
body.put("toolCalls", Value.of(toolCallValues));
427542
}
428543
if (stopReason != null) {
429544
body.put("finish_reason", Value.of(stopReason.toString()));
@@ -451,4 +566,9 @@ private static String serializeDocument(Document document) {
451566
document.accept(marshaller);
452567
return new String(generator.getBytes(), StandardCharsets.UTF_8);
453568
}
569+
570+
/**
 * Parses a JSON string with the shared {@code JSON_PARSER} and converts the resulting node by
 * visiting it with the shared {@code DOCUMENT_UNMARSHALLER}.
 *
 * @param json the JSON text to parse
 * @return the value produced by the unmarshaller for the parsed node
 */
private static Document deserializeDocument(String json) {
571+
// NOTE(review): behavior for empty or malformed input depends on the parser - confirm.
JsonNode node = JSON_PARSER.parse(json);
572+
return node.visit(DOCUMENT_UNMARSHALLER);
573+
}
454574
}

0 commit comments

Comments
 (0)